|
|
"""API Router with /chat endpoint for Swagger testing.""" |
|
|
|
|
|
from enum import Enum |
|
|
from pydantic import BaseModel, Field |
|
|
from fastapi import APIRouter, Depends, UploadFile, File, Query, HTTPException |
|
|
from sqlalchemy.ext.asyncio import AsyncSession |
|
|
|
|
|
from app.agent.mmca_agent import MMCAAgent |
|
|
from app.agent.react_agent import ReActAgent |
|
|
from app.shared.db.session import get_db |
|
|
from app.core.config import settings |
|
|
from app.mcp.tools import mcp_tools |
|
|
from app.shared.chat_history import chat_history |
|
|
|
|
|
|
|
|
router = APIRouter() |
|
|
|
|
|
|
|
|
class LLMProvider(str, Enum):
    """Available LLM providers.

    str-valued so members serialize/deserialize as plain strings in
    request and response JSON.
    """

    # Gemini family (see ChatRequest.model default resolution in /chat).
    GOOGLE = "Google"
    # DeepSeek family via MegaLLM; default provider for ChatRequest.
    MEGALLM = "MegaLLM"
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Chat request model.

    Payload for POST /chat: the user message plus session identity,
    provider/model selection, and agent-mode options.
    """

    # Natural-language query forwarded to the agent.
    message: str = Field(
        ...,
        description="User message in natural language",
        examples=["Tìm quán cafe gần bãi biển Mỹ Khê"],
    )
    # Keys the per-user chat history; defaults to "anonymous".
    user_id: str = Field(
        default="anonymous",
        description="User ID for session management",
        examples=["user_123", "anonymous"],
    )
    # The /chat handler substitutes "default" when this is omitted.
    session_id: str | None = Field(
        None,
        description="Session ID (optional, uses 'default' if not provided)",
        examples=["session_abc", "default"],
    )
    # Passed through to the agent for visual similarity search.
    image_url: str | None = Field(
        None,
        description="Optional image URL for visual similarity search",
        examples=["https://example.com/cafe.jpg"],
    )
    provider: LLMProvider = Field(
        default=LLMProvider.MEGALLM,
        description="LLM provider to use: Google or MegaLLM",
    )
    # When None, /chat falls back to the provider-specific default from settings.
    model: str | None = Field(
        None,
        description=f"Model name. Defaults: Google={settings.default_gemini_model}, MegaLLM={settings.default_megallm_model}",
        examples=["gemini-2.0-flash", "deepseek-r1-distill-llama-70b"],
    )
    # Selects ReActAgent (multi-step) instead of the default MMCAAgent.
    react_mode: bool = Field(
        default=False,
        description="Enable ReAct multi-step reasoning mode",
    )
    # Only used when react_mode is True; validated to 1..10.
    max_steps: int = Field(
        default=5,
        description="Maximum reasoning steps for ReAct mode",
        ge=1,
        le=10,
    )
|
|
|
|
|
|
|
|
class WorkflowStepResponse(BaseModel):
    """Workflow step info.

    One entry in a WorkflowResponse trace describing a single agent step.
    """

    step: str = Field(..., description="Step name")
    # None when the step used no tool.
    tool: str | None = Field(None, description="Tool used")
    purpose: str = Field(default="", description="Purpose of this step")
    results: int = Field(default=0, description="Number of results")
|
|
|
|
|
|
|
|
class WorkflowResponse(BaseModel):
    """Workflow trace for debugging.

    Serializable record of how the agent processed a query: detected
    intent, tools invoked, and per-step details.
    """

    query: str = Field(..., description="Original query")
    intent_detected: str = Field(..., description="Detected intent")
    tools_used: list[str] = Field(default_factory=list, description="Tools used")
    steps: list[WorkflowStepResponse] = Field(default_factory=list, description="Workflow steps")
    total_duration_ms: float = Field(..., description="Total processing time")
|
|
|
|
|
|
|
|
class PlaceItem(BaseModel):
    """Place item for FE rendering.

    Built by enrich_places_from_ids from places_metadata rows; the /chat
    handler may fill distance_km afterwards from tool results.
    """

    place_id: str
    name: str
    category: str | None = None
    lat: float | None = None
    lng: float | None = None
    rating: float | None = None
    # Not stored in places_metadata; populated post-hoc when a tool
    # result carried a distance for this place_id.
    distance_km: float | None = None
    address: str | None = None
    image_url: str | None = None
|
|
|
|
|
|
|
|
class ChatResponse(BaseModel):
    """Chat response model.

    Returned by POST /chat for both ReAct and MMCA agent modes.
    """

    response: str = Field(..., description="Agent's response")
    status: str = Field(default="success", description="Response status")
    # Echoes of the effective provider/model/session used for this turn.
    provider: str = Field(..., description="LLM provider used")
    model: str = Field(..., description="Model used")
    user_id: str = Field(..., description="User ID")
    session_id: str = Field(..., description="Session ID used")
    # DB-enriched details for the place IDs the LLM selected.
    places: list[PlaceItem] = Field(default_factory=list, description="LLM-selected places for FE rendering")
    tools_used: list[str] = Field(default_factory=list, description="MCP tools used")
    duration_ms: float = Field(default=0, description="Total processing time in ms")
|
|
|
|
|
|
|
|
class NearbyRequest(BaseModel):
    """Nearby places request model.

    Payload for POST /nearby; mirrors the parameters of the
    find_nearby_places MCP tool.
    """

    lat: float = Field(..., description="Latitude", examples=[16.0626442])
    lng: float = Field(..., description="Longitude", examples=[108.2462143])
    # Search radius; no upper bound is enforced here.
    max_distance_km: float = Field(
        default=5.0,
        description="Maximum distance in kilometers",
        examples=[5.0, 18.72],
    )
    # None means no category filtering.
    category: str | None = Field(
        None,
        description="Category filter (cafe, restaurant, attraction, etc.)",
        examples=["cafe", "restaurant"],
    )
    limit: int = Field(default=10, description="Maximum results", examples=[10, 20])
|
|
|
|
|
|
|
|
class PlaceResponse(BaseModel):
    """Place response model.

    One result row in NearbyResponse, mapped from a find_nearby_places hit.
    """

    place_id: str
    name: str
    category: str | None = None
    lat: float | None = None
    lng: float | None = None
    # Distance from the query point, as reported by the spatial query.
    distance_km: float | None = None
    rating: float | None = None
    description: str | None = None
|
|
|
|
|
|
|
|
class NearbyResponse(BaseModel):
    """Nearby places response model."""

    places: list[PlaceResponse]
    # Always len(places).
    count: int
    # Echo of the effective query parameters (lat, lng, max_distance_km, category).
    query: dict
|
|
|
|
|
|
|
|
class ClearHistoryRequest(BaseModel):
    """Clear history request model.

    Payload for POST /chat/clear.
    """

    user_id: str = Field(..., description="User ID to clear history for")
    # When omitted, every session for the user is cleared.
    session_id: str | None = Field(
        None,
        description="Session ID to clear (clears all if not provided)",
    )
|
|
|
|
|
|
|
|
class HistoryResponse(BaseModel):
    """Chat history response model.

    Session summary returned by GET /chat/history/{user_id}.
    """

    user_id: str
    sessions: list[str]
    # "default" if present, otherwise the first session id, otherwise None.
    current_session: str | None
    message_count: int
|
|
|
|
|
|
|
|
class MessageItem(BaseModel):
    """Single chat message."""

    # "user" or "assistant" — the roles written by the /chat handler.
    role: str
    content: str
    # ISO-8601 string (produced via datetime.isoformat() when serving messages).
    timestamp: str
|
|
|
|
|
|
|
|
class MessagesResponse(BaseModel):
    """Chat messages response.

    Returned by GET /chat/messages/{user_id} for one session.
    """

    user_id: str
    session_id: str
    messages: list[MessageItem]
    # Always len(messages).
    count: int
|
|
|
|
|
|
|
|
@router.post(
    "/nearby",
    response_model=NearbyResponse,
    summary="Find nearby places (Neo4j)",
    description="""
    Find places near a given location using Neo4j spatial query.

    This endpoint directly tests the `find_nearby_places` MCP tool.

    ## Test Cases
    - Case 1: lat=16.0626442, lng=108.2462143, max_distance_km=18.72
    - Case 2: lat=16.0623184, lng=108.2306049, max_distance_km=17.94
    """,
)
async def find_nearby(request: NearbyRequest) -> NearbyResponse:
    """
    Find nearby places using Neo4j graph database.

    Directly calls the find_nearby_places MCP tool and maps each hit
    onto the public PlaceResponse schema.
    """
    hits = await mcp_tools.find_nearby_places(
        lat=request.lat,
        lng=request.lng,
        max_distance_km=request.max_distance_km,
        category=request.category,
        limit=request.limit,
    )

    # Map raw tool hits onto the response schema one by one.
    mapped: list[PlaceResponse] = []
    for hit in hits:
        mapped.append(
            PlaceResponse(
                place_id=hit.place_id,
                name=hit.name,
                category=hit.category,
                lat=hit.lat,
                lng=hit.lng,
                distance_km=hit.distance_km,
                rating=hit.rating,
                description=hit.description,
            )
        )

    # Echo the effective search parameters back to the caller.
    echoed_query = {
        "lat": request.lat,
        "lng": request.lng,
        "max_distance_km": request.max_distance_km,
        "category": request.category,
    }
    return NearbyResponse(places=mapped, count=len(mapped), query=echoed_query)
|
|
|
|
|
|
|
|
async def enrich_places_from_ids(place_ids: list[str], db: AsyncSession) -> list[PlaceItem]:
    """
    Enrich LLM-selected place_ids with full details from DB.

    Looks up the given IDs in ``places_metadata`` (the PostGIS point is
    split into lng/lat via ST_X/ST_Y) and returns PlaceItem rows in the
    same order the LLM selected them. IDs with no matching row are
    silently dropped.

    Args:
        place_ids: List of place_ids selected by LLM in synthesis
        db: Database session

    Returns:
        List of PlaceItem with full details
    """
    if not place_ids:
        return []

    from sqlalchemy import text

    result = await db.execute(
        text("""
        SELECT place_id, name, category, address, rating,
               ST_X(coordinates::geometry) as lng,
               ST_Y(coordinates::geometry) as lat
        FROM places_metadata
        WHERE place_id = ANY(:place_ids)
        """),
        {"place_ids": place_ids}
    )
    rows = result.fetchall()

    # Index by id so output preserves the LLM's selection order,
    # not the DB's return order.
    places_dict = {row.place_id: row for row in rows}
    places = []
    for pid in place_ids:
        if pid in places_dict:
            row = places_dict[pid]
            places.append(PlaceItem(
                place_id=row.place_id,
                name=row.name,
                category=row.category,
                lat=row.lat,
                lng=row.lng,
                # Explicit None check so a legitimate rating of 0 is kept
                # as 0.0 instead of being coerced to None by truthiness.
                rating=float(row.rating) if row.rating is not None else None,
                address=row.address,
            ))

    return places
|
|
|
|
|
|
|
|
@router.post(
    "/chat",
    response_model=ChatResponse,
    summary="Chat with LocalMate Agent",
    description="""
    Chat with the Multi-Modal Contextual Agent (MMCA).

    ## Session Management
    - Each user can have up to **3 sessions** stored
    - Provide `user_id` and optional `session_id` to maintain conversation history
    - History is automatically injected into the agent prompt

    ## LLM Providers
    - **Google**: Gemini models (gemini-2.0-flash, etc.)
    - **MegaLLM**: DeepSeek models (deepseek-r1-distill-llama-70b, etc.)

    ## Examples
    - "Tìm quán cafe gần bãi biển Mỹ Khê"
    - "Nhà hàng hải sản nào gần Cầu Rồng?"
    """,
)
async def chat(
    request: ChatRequest,
    db: AsyncSession = Depends(get_db),
) -> ChatResponse:
    """
    Chat endpoint with per-user history support.

    Resolves the model for the chosen provider, loads recent conversation
    history, records the user turn, runs either the ReAct agent
    (multi-step reasoning) or the default MMCA agent, records the
    assistant reply, and enriches LLM-selected place IDs with full
    details from the database.
    """
    # Resolve the model: explicit request value wins, otherwise the
    # provider-specific default from settings.
    if request.model:
        model = request.model
    elif request.provider == LLMProvider.GOOGLE:
        model = settings.default_gemini_model
    else:
        model = settings.default_megallm_model

    session_id = request.session_id or "default"

    # Only the last few turns are injected into the agent prompt.
    history = chat_history.get_history(
        user_id=request.user_id,
        session_id=session_id,
        max_messages=6,
    )

    # Record the user turn before invoking the agent.
    chat_history.add_message(
        user_id=request.user_id,
        role="user",
        content=request.message,
        session_id=session_id,
    )

    if request.react_mode:
        # Multi-step reasoning agent with a bounded step count.
        agent = ReActAgent(
            provider=request.provider.value,
            model=model,
            max_steps=request.max_steps,
        )
        response_text, agent_state = await agent.run(
            query=request.message,
            db=db,
            image_url=request.image_url,
            history=history,
        )
        # NOTE(review): the previous version also built a workflow trace
        # (agent.to_workflow(...).to_dict()) but never used it; removed.

        chat_history.add_message(
            user_id=request.user_id,
            role="assistant",
            content=response_text,
            session_id=session_id,
        )

        places = await enrich_places_from_ids(agent_state.selected_place_ids, db)

        return ChatResponse(
            response=response_text,
            status="success",
            provider=request.provider.value,
            model=model,
            user_id=request.user_id,
            session_id=session_id,
            places=places,
            tools_used=list(agent_state.context.keys()),
            duration_ms=agent_state.total_duration_ms,
        )

    # Default single-pass MMCA agent.
    agent = MMCAAgent(provider=request.provider.value, model=model)

    result = await agent.chat(
        message=request.message,
        db=db,
        image_url=request.image_url,
        history=history,
    )

    chat_history.add_message(
        user_id=request.user_id,
        role="assistant",
        content=result.response,
        session_id=session_id,
    )

    places = []
    if result.selected_place_ids:
        places = await enrich_places_from_ids(result.selected_place_ids, db)

    # DB enrichment carries no distance; recover distance_km from raw
    # tool results when a tool reported it for a selected place.
    distance_map = {}
    for tool_call in result.tool_results:
        if tool_call.result:
            for item in tool_call.result:
                if isinstance(item, dict) and 'place_id' in item and 'distance_km' in item:
                    distance_map[item['place_id']] = item['distance_km']
    for place in places:
        if place.place_id in distance_map:
            place.distance_km = distance_map[place.place_id]

    return ChatResponse(
        response=result.response,
        status="success",
        provider=request.provider.value,
        model=model,
        user_id=request.user_id,
        session_id=session_id,
        places=places,
        tools_used=result.tools_used,
        duration_ms=result.total_duration_ms,
    )
|
|
|
|
|
|
|
|
@router.post(
    "/chat/clear",
    summary="Clear chat history",
    description="Clears the conversation history for a specific user/session.",
)
async def clear_chat(request: ClearHistoryRequest):
    """
    Clear conversation history for a user.

    Clears one session when ``session_id`` is given, otherwise every
    session belonging to the user.
    """
    uid = request.user_id
    sid = request.session_id

    if sid:
        chat_history.clear_session(uid, sid)
        message = f"Session '{sid}' cleared for user '{uid}'"
    else:
        chat_history.clear_all_sessions(uid)
        message = f"All sessions cleared for user '{uid}'"

    return {"status": "success", "message": message}
|
|
|
|
|
|
|
|
@router.get(
    "/chat/history/{user_id}",
    response_model=HistoryResponse,
    summary="Get chat history info",
    description="Get information about user's chat sessions.",
)
async def get_history_info(user_id: str) -> HistoryResponse:
    """
    Get chat history information for a user.

    Reports the user's session ids, a "current" session (preferring
    'default'), and the total message count.
    """
    session_ids = chat_history.get_session_ids(user_id)
    all_messages = chat_history.get_messages(user_id)

    # Prefer the 'default' session; fall back to the first one, if any.
    if "default" in session_ids:
        current = "default"
    elif session_ids:
        current = session_ids[0]
    else:
        current = None

    return HistoryResponse(
        user_id=user_id,
        sessions=session_ids,
        current_session=current,
        message_count=len(all_messages),
    )
|
|
|
|
|
|
|
|
@router.get(
    "/chat/messages/{user_id}",
    response_model=MessagesResponse,
    summary="Get chat messages",
    description="Get actual chat messages from a specific session.",
)
async def get_chat_messages(
    user_id: str,
    session_id: str = "default",
) -> MessagesResponse:
    """
    Get chat messages for a session.

    Serializes each stored message (timestamps as ISO-8601 strings).
    """
    stored = chat_history.get_messages(user_id, session_id)

    items: list[MessageItem] = []
    for msg in stored:
        items.append(
            MessageItem(
                role=msg.role,
                content=msg.content,
                timestamp=msg.timestamp.isoformat(),
            )
        )

    return MessagesResponse(
        user_id=user_id,
        session_id=session_id,
        messages=items,
        count=len(items),
    )
|
|
|
|
|
|
|
|
class ImageSearchResult(BaseModel):
    """Image search result model.

    One visually-similar place returned by POST /search/image.
    """

    place_id: str
    name: str
    category: str | None = None
    rating: float | None = None
    # Similarity score from the embedding search — presumably 0..1
    # cosine similarity; TODO confirm against search_by_image_bytes.
    similarity: float
    # Number of the place's images that matched the query image.
    matched_images: int = 1
    image_url: str | None = None
|
|
|
|
|
|
|
|
class ImageSearchResponse(BaseModel):
    """Image search response model."""

    results: list[ImageSearchResult]
    # Always len(results).
    total: int
|
|
|
|
|
|
|
|
@router.post(
    "/search/image",
    response_model=ImageSearchResponse,
    summary="Search places by image",
    description="""
    Upload an image to find visually similar places.

    Uses image embeddings stored in Supabase pgvector.
    """,
)
async def search_by_image(
    image: UploadFile = File(..., description="Image file to search"),
    limit: int = Query(10, ge=1, le=50, description="Maximum results"),
    db: AsyncSession = Depends(get_db),
) -> ImageSearchResponse:
    """
    Search places by uploading an image.

    Reads the upload (rejecting files over 10 MB with a 400), runs the
    visual-embedding search via the MCP tool, and maps the hits onto
    the public ImageSearchResult schema.

    Raises:
        HTTPException: 400 for oversized uploads; 500 for any other
            failure during the search.
    """
    try:
        image_bytes = await image.read()

        # Guard against oversized uploads before doing any embedding work.
        if len(image_bytes) > 10 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="Image too large (max 10MB)")

        results = await mcp_tools.search_by_image_bytes(
            db=db,
            image_bytes=image_bytes,
            limit=limit,
        )

        return ImageSearchResponse(
            results=[
                ImageSearchResult(
                    place_id=r.place_id,
                    name=r.name,
                    category=r.category,
                    rating=r.rating,
                    similarity=r.similarity,
                    matched_images=r.matched_images,
                    image_url=r.image_url,
                )
                for r in results
            ],
            total=len(results),
        )
    except HTTPException:
        # Re-raise our own 400 untouched rather than wrapping it as a 500.
        raise
    except Exception as e:
        # Chain the original exception so the real traceback is preserved
        # in server logs (previously lost: no `from e`).
        raise HTTPException(status_code=500, detail=f"Image search error: {str(e)}") from e
|
|
|
|
|
|