LocalMate / app /agent /state.py
Cuong2004's picture
fix api response
ac07893
"""Agent State Models - State management for ReAct multi-step reasoning."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class ReActStep:
"""A single step in the ReAct reasoning loop."""
step_number: int
thought: str # LLM's reasoning about what to do
action: str # Tool name or "finish"
action_input: dict # Tool arguments
observation: Any = None # Tool result
duration_ms: float = 0
timestamp: datetime = field(default_factory=datetime.now)
def to_dict(self) -> dict:
"""Convert to dictionary for serialization."""
return {
"step": self.step_number,
"thought": self.thought,
"action": self.action,
"action_input": self.action_input,
"observation": self._truncate_observation(),
"duration_ms": round(self.duration_ms, 1),
}
def _truncate_observation(self, max_items: int = 3) -> Any:
"""Truncate observation for display."""
if isinstance(self.observation, list) and len(self.observation) > max_items:
return self.observation[:max_items] + [f"... and {len(self.observation) - max_items} more"]
return self.observation
@dataclass
class AgentState:
"""Complete state for a ReAct agent session."""
query: str
steps: list[ReActStep] = field(default_factory=list)
context: dict = field(default_factory=dict) # Accumulated context from tools
current_step: int = 0
max_steps: int = 5
is_complete: bool = False
final_answer: str = ""
selected_place_ids: list[str] = field(default_factory=list) # LLM-selected places
total_duration_ms: float = 0
error: str | None = None
def add_step(self, step: ReActStep) -> None:
"""Add a completed step to the state."""
self.steps.append(step)
self.current_step += 1
# Add tool result to context
if step.action != "finish" and step.observation:
self.context[step.action] = step.observation
def can_continue(self) -> bool:
"""Check if agent can continue reasoning."""
return (
not self.is_complete
and self.current_step < self.max_steps
and self.error is None
)
def get_context_summary(self) -> str:
"""Get a summary of accumulated context for LLM."""
if not self.context:
return "Chưa có kết quả từ các tools trước đó."
summary_parts = []
for tool_name, result in self.context.items():
if isinstance(result, list):
summary_parts.append(f"- {tool_name}: {len(result)} kết quả")
elif isinstance(result, dict):
summary_parts.append(f"- {tool_name}: {result}")
else:
summary_parts.append(f"- {tool_name}: {str(result)[:100]}")
return "Kết quả từ các bước trước:\n" + "\n".join(summary_parts)
def to_dict(self) -> dict:
"""Convert to dictionary for API response."""
return {
"query": self.query,
"total_steps": len(self.steps),
"max_steps": self.max_steps,
"is_complete": self.is_complete,
"steps": [s.to_dict() for s in self.steps],
"tools_used": list(self.context.keys()),
"total_duration_ms": round(self.total_duration_ms, 1),
"error": self.error,
}