"""
agent_variant.py
The fourth implementation variant:
* Automatically collects every function decorated with @tool
* Built-in simple in-memory cache to avoid answering the same question twice
* Splits "similar-question retrieval" and "tool calling" into separate graph nodes
* Builds the LLM via a factory, supporting three providers: google / groq / huggingface
"""
from __future__ import annotations
import os
from typing import Dict, List, Any
from dotenv import load_dotenv
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import BaseTool, tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import SupabaseVectorStore
from supabase.client import create_client
load_dotenv()
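# Environment variables expected in .env (illustrative, depending on which providers
# and tools are actually used): SUPABASE_URL, SUPABASE_SERVICE_KEY, TAVILY_API_KEY,
# and one of GOOGLE_API_KEY / GROQ_API_KEY / HUGGINGFACEHUB_API_TOKEN.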
# --------------------------------------------------------------------------------------
# 1. Tool definitions (arithmetic & retrieval)
# --------------------------------------------------------------------------------------
@tool
def multiply(x: int, y: int) -> int:
    """Return x * y."""
    return x * y

@tool
def plus(x: int, y: int) -> int:
    """Return x + y."""
    return x + y

@tool
def minus(x: int, y: int) -> int:
    """Return x - y."""
    return x - y

@tool
def divide(x: int, y: int) -> float:
    """Return x / y; raise if y == 0."""
    if y == 0:
        raise ValueError("Cannot divide by zero.")
    return x / y

@tool
def modulo(x: int, y: int) -> int:
    """Return x % y."""
    return x % y

@tool
def wiki(query: str) -> str:
    """Top-2 Wikipedia docs for *query*."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(d.page_content for d in docs)
@tool
def tavily(query: str) -> str:
    """Top-3 Tavily results for *query*."""
    # TavilySearchResults returns a list of dicts with "url" / "content" keys.
    results = TavilySearchResults(max_results=3).invoke(query)
    return "\n\n---\n\n".join(r["content"] for r in results)
@tool
def arxiv(query: str) -> str:
    """Top-3 arXiv abstracts for *query*."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(d.page_content[:1000] for d in docs)
# Automatically collect every tool defined above. @tool wraps each function in a
# BaseTool (StructuredTool) instance, so an isinstance check finds them reliably.
TOOLS = [obj for obj in globals().values() if isinstance(obj, BaseTool)]
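# Illustrative check of what gets collected (tool names default to the function names):
#   >>> sorted(t.name for t in TOOLS)
#   ['arxiv', 'divide', 'minus', 'modulo', 'multiply', 'plus', 'tavily', 'wiki']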
# --------------------------------------------------------------------------------------
# 2. Simple in-memory cache to avoid recomputing the same answer
# --------------------------------------------------------------------------------------
_CACHE: Dict[str, AIMessage] = {}
def cache_lookup(state: MessagesState) -> Dict[str, List[Any]]:
    question = state["messages"][-1].content
    if question in _CACHE:        # cache hit: append the cached answer
        return {"messages": [_CACHE[question]]}
    return {"messages": []}       # cache miss: no state update
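# Design note: a cache hit does not short-circuit the graph; the cached AIMessage is
# appended to the conversation before "retriever" and "assistant" run, so the assistant
# sees the earlier answer as context rather than returning it verbatim.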
# --------------------------------------------------------------------------------------
# 3. Vector retrieval (Supabase) used as a "similar question" hint
# --------------------------------------------------------------------------------------
def build_vector_store() -> SupabaseVectorStore:
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    client = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])
    return SupabaseVectorStore(
        client=client,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
VECTOR_STORE = build_vector_store()
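# Assumes the Supabase project already contains a "documents" table and a
# "match_documents_langchain" RPC compatible with SupabaseVectorStore; without them the
# similarity search below will fail at runtime.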
def similar_q(state: MessagesState):
    query = state["messages"][-1].content
    hits = VECTOR_STORE.similarity_search(query, k=1)
    if hits:
        hint = AIMessage(content=f"(Reference answer)\n{hits[0].page_content}")
        return {"messages": [hint]}
    return {"messages": []}  # nothing similar found: no state update
# --------------------------------------------------------------------------------------
# 4. LLM factory
# --------------------------------------------------------------------------------------
def make_llm(provider: str):
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-pro", temperature=0)
    if provider == "groq":
        return ChatGroq(model="llama3-70b-8192", temperature=0)
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct",
                temperature=0,
            )
        )
    raise ValueError(f"Unknown provider: {provider}")
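# Illustrative usage (assuming the matching API key, e.g. GROQ_API_KEY, is configured):
#   llm = make_llm("groq")
#   llm.invoke("Hello")  # -> AIMessage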
# --------------------------------------------------------------------------------------
# 5. Load the system prompt
# --------------------------------------------------------------------------------------
with open("system_prompt.txt", encoding="utf-8") as fp:
    SYSTEM_PROMPT = SystemMessage(content=fp.read())
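# system_prompt.txt is expected in the working directory; its contents become the
# SystemMessage that anchors every conversation.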
# --------------------------------------------------------------------------------------
# 6. Build the LangGraph
# --------------------------------------------------------------------------------------
def build_graph(provider: str = "groq"):
    llm = make_llm(provider).bind_tools(TOOLS)

    def assistant(state: MessagesState):
        return {"messages": [llm.invoke(state["messages"])]}

    builder = StateGraph(MessagesState)
    # Nodes
    builder.add_node("cache", cache_lookup)
    builder.add_node("retriever", similar_q)
    builder.add_node("assistant", assistant)
    builder.add_node("toolcaller", ToolNode(TOOLS))
    # Edges
    builder.add_edge(START, "cache")
    builder.add_edge("cache", "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to the tool node or finish. tools_condition returns the literal "tools"
    # when the last AI message contains tool calls, otherwise END, so map "tools"
    # onto our node name.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {
            "tools": "toolcaller",
            END: END,
        },
    )
    builder.add_edge("toolcaller", "assistant")
    return builder.compile()
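# Resulting topology:
#   START -> cache -> retriever -> assistant -> (toolcaller <-> assistant)* -> END
# graph.get_graph().draw_mermaid() can be used to render it for inspection.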
# --------------------------------------------------------------------------------------
# 7. CLI usage example
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    question = " ".join(sys.argv[1:]) or "Who won the Nobel physics prize in 1921?"
    graph = build_graph("groq")
    init_msgs = [SYSTEM_PROMPT, HumanMessage(content=question)]
    result = graph.invoke({"messages": init_msgs})
    for msg in result["messages"]:
        msg.pretty_print()
    # Cache the final answer for repeated questions within the same process.
    _CACHE[question] = result["messages"][-1]
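# Illustrative invocation (assuming the required API keys are present in .env):
#   python agent_variant.py "What is 17 * 23?"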