""" agent_variant.py 第四种实现方式: * 自动收集所有 @tool 装饰的函数 * 内置简单内存缓存,避免同问多答 * 将“相似问题检索”与“工具调用”全部拆成独立节点 * 采用工厂模式生成 LLM,支持 google / groq / huggingface 三种 provider """ from __future__ import annotations import os from typing import Dict, List, Any from dotenv import load_dotenv from langgraph.graph import StateGraph, START, END, MessagesState from langgraph.prebuilt import ToolNode, tools_condition from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from langchain_core.tools import tool from langchain_google_genai import ChatGoogleGenerativeAI from langchain_groq import ChatGroq from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings from langchain_community.document_loaders import WikipediaLoader, ArxivLoader from langchain_community.tools.tavily_search import TavilySearchResults from langchain_community.vectorstores import SupabaseVectorStore from supabase.client import create_client load_dotenv() # -------------------------------------------------------------------------------------- # 1. 工具定义(算术 & 检索) # -------------------------------------------------------------------------------------- @tool def multiply(x: int, y: int) -> int: """Return x * y.""" return x * y @tool def plus(x: int, y: int) -> int: """Return x + y.""" return x + y @tool def minus(x: int, y: int) -> int: """Return x - y.""" return x - y @tool def divide(x: int, y: int) -> float: """Return x / y; raise if y==0.""" if y == 0: raise ValueError("Cannot divide by zero.") return x / y @tool def modulo(x: int, y: int) -> int: """Return x % y.""" return x % y @tool def wiki(query: str) -> str: """Top-2 Wikipedia docs for *query*.""" docs = WikipediaLoader(query=query, load_max_docs=2).load() return "\n\n---\n\n".join(d.page_content for d in docs) @tool def tavily(query: str) -> str: """Top-3 Tavily results for *query*.""" docs = TavilySearchResults(max_results=3).invoke(query=query) return "\n\n---\n\n".join(d.page_content for d in docs) @tool def arxiv(query: str) -> str: """Top-3 arXiv abstracts for *query*.""" docs = ArxivLoader(query=query, load_max_docs=3).load() return "\n\n---\n\n".join(d.page_content[:1000] for d in docs) # 自动收集全部工具 TOOLS = [obj for obj in globals().values() if callable(obj) and getattr(obj, "__tool", False)] # -------------------------------------------------------------------------------------- # 2. 简易内存级缓存,避免重复计算 # -------------------------------------------------------------------------------------- _CACHE: Dict[str, AIMessage] = {} def cache_lookup(state: MessagesState) -> Dict[str, List[Any]]: question = state["messages"][-1].content if question in _CACHE: # 命中缓存,直接返回 return {"messages": [_CACHE[question]]} return state # miss 时原样透传 # -------------------------------------------------------------------------------------- # 3. 向量检索(Supabase)作为“相似问题提示” # -------------------------------------------------------------------------------------- def build_vector_store() -> SupabaseVectorStore: embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") client = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"]) return SupabaseVectorStore( client=client, embedding=embeddings, table_name="documents", query_name="match_documents_langchain", ) VECTOR_STORE = build_vector_store() def similar_q(state: MessagesState): query = state["messages"][-1].content hits = VECTOR_STORE.similarity_search(query, k=1) if hits: hint = AIMessage(content=f"(参考答案)\n{hits[0].page_content}") return {"messages": [hint]} return state # -------------------------------------------------------------------------------------- # 4. LLM 工厂 # -------------------------------------------------------------------------------------- def make_llm(provider: str): if provider == "google": return ChatGoogleGenerativeAI(model="gemini-2.0-pro", temperature=0) if provider == "groq": return ChatGroq(model="llama3-70b-8192", temperature=0) if provider == "huggingface": return ChatHuggingFace( llm=HuggingFaceEndpoint( url="/static-proxy?url=https%3A%2F%2Fapi-inference.huggingface.co%2Fmodels%2Fmicrosoft%2FPhi-3-mini-4k-instruct", temperature=0 ) ) raise ValueError(f"Unknown provider: {provider}") # -------------------------------------------------------------------------------------- # 5. 加载系统提示 # -------------------------------------------------------------------------------------- with open("system_prompt.txt", encoding="utf-8") as fp: SYSTEM_PROMPT = SystemMessage(content=fp.read()) # -------------------------------------------------------------------------------------- # 6. 构建 LangGraph # -------------------------------------------------------------------------------------- def build_graph(provider: str = "groq"): llm = make_llm(provider).bind_tools(TOOLS) def assistant(state: MessagesState): return {"messages": [llm.invoke(state["messages"])]} builder = StateGraph(MessagesState) # 节点 builder.add_node("cache", cache_lookup) builder.add_node("retriever", similar_q) builder.add_node("assistant", assistant) builder.add_node("toolcaller", ToolNode(TOOLS)) # 边 builder.add_edge(START, "cache") builder.add_edge("cache", "retriever") builder.add_edge("retriever", "assistant") # 工具调用与结束判断 builder.add_conditional_edges( "assistant", tools_condition, { "toolcaller": "toolcaller", END: END } ) builder.add_edge("toolcaller", "assistant") return builder.compile() # -------------------------------------------------------------------------------------- # 7. CLI 用法示例 # -------------------------------------------------------------------------------------- if __name__ == "__main__": import sys question = " ".join(sys.argv[1:]) or "Who won the Nobel physics prize in 1921?" graph = build_graph("groq") init_msgs = [SYSTEM_PROMPT, HumanMessage(content=question)] result = graph.invoke({"messages": init_msgs}) for msg in result["messages"]: msg.pretty_print() # 写入缓存 _CACHE[question] = result["messages"][-1]