Spaces:
Runtime error
Runtime error
"""
agent_variant.py

Fourth implementation variant of the agent:

* Tool functions are gathered into the module-level ``TOOLS`` list, which is
  bound to the LLM and also served by a ``ToolNode``.
* A simple in-memory cache (``_CACHE``) keyed on the question text holds
  previously produced answers so a repeated question can reuse one.
* "Similar question" retrieval (Supabase vector store) and tool calling are
  separate, independent graph nodes.
* LLMs are created through a factory (``make_llm``) that supports the
  google / groq / huggingface providers.
"""
| from __future__ import annotations | |
| import os | |
| from typing import Dict, List, Any | |
| from dotenv import load_dotenv | |
| from langgraph.graph import StateGraph, START, END, MessagesState | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from langchain_core.tools import tool | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings | |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.vectorstores import SupabaseVectorStore | |
| from supabase.client import create_client | |
| load_dotenv() | |
| # -------------------------------------------------------------------------------------- | |
| # 1. 工具定义(算术 & 检索) | |
| # -------------------------------------------------------------------------------------- | |
def multiply(x: int, y: int) -> int:
    """Multiply two integers.

    Args:
        x: First factor.
        y: Second factor.

    Returns:
        The product of x and y.
    """
    product = x * y
    return product
def plus(x: int, y: int) -> int:
    """Add two integers.

    Args:
        x: First addend.
        y: Second addend.

    Returns:
        The sum of x and y.
    """
    total = x + y
    return total
def minus(x: int, y: int) -> int:
    """Subtract one integer from another.

    Args:
        x: Minuend.
        y: Subtrahend.

    Returns:
        The difference x minus y.
    """
    difference = x - y
    return difference
def divide(x: int, y: int) -> float:
    """Divide one integer by another (true division).

    Args:
        x: Dividend.
        y: Divisor; must be non-zero.

    Returns:
        The quotient x / y.

    Raises:
        ValueError: If y is zero.
    """
    if y != 0:
        return x / y
    raise ValueError("Cannot divide by zero.")
def modulo(x: int, y: int) -> int:
    """Compute the remainder of dividing one integer by another.

    Args:
        x: Dividend.
        y: Divisor; a zero divisor raises ZeroDivisionError.

    Returns:
        x % y (Python semantics: the result takes the sign of y).
    """
    remainder = x % y
    return remainder
def wiki(query: str) -> str:
    """Search Wikipedia.

    Args:
        query: Search terms.

    Returns:
        The text of up to two matching Wikipedia pages, separated by '---' lines.
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    pages = [page.page_content for page in loader.load()]
    return "\n\n---\n\n".join(pages)
def tavily(query: str) -> str:
    """Search the web with Tavily.

    Args:
        query: Search terms.

    Returns:
        The content snippets of up to three Tavily results, separated by
        '---' lines, or the tool's error text if the search failed.
    """
    # BaseTool.invoke takes the tool input as its first positional argument
    # (a dict of tool arguments or a plain string); it has no ``query`` kwarg.
    results = TavilySearchResults(max_results=3).invoke({"query": query})
    # The tool may hand back an error string instead of a list of results.
    if isinstance(results, str):
        return results
    # Each result is a dict (url, content, ...), not a Document, so there is
    # no ``page_content`` attribute to read.
    return "\n\n---\n\n".join(item.get("content", "") for item in results)
def arxiv(query: str) -> str:
    """Search arXiv.

    Args:
        query: Search terms.

    Returns:
        The first 1000 characters of up to three matching arXiv documents,
        separated by '---' lines.
    """
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    snippets = (paper.page_content[:1000] for paper in papers)
    return "\n\n---\n\n".join(snippets)
# Every tool exposed to the LLM, listed explicitly so the set is deterministic.
# The functions above are plain, docstring-annotated callables; both
# ``bind_tools`` and ``ToolNode`` accept those directly (as well as BaseTool
# instances), using the function name as the tool name.
TOOLS = [multiply, plus, minus, divide, modulo, wiki, tavily, arxiv]
| # -------------------------------------------------------------------------------------- | |
| # 2. 简易内存级缓存,避免重复计算 | |
| # -------------------------------------------------------------------------------------- | |
# In-process answer cache: question text -> answer message stored for it.
_CACHE: Dict[str, AIMessage] = {}


def cache_lookup(state: MessagesState) -> Dict[str, List[Any]]:
    """Look up a previously stored answer for the latest message.

    Args:
        state: Graph state; the content of its last message is used as the
            cache key.

    Returns:
        ``{"messages": [cached_answer]}`` on a cache hit; otherwise ``state``
        itself is passed through unchanged.
    """
    messages = state["messages"]
    question = messages[-1].content if messages else None
    # Multimodal content is a list (unhashable), so only plain-text questions
    # can be looked up; anything else is treated as a miss.
    if isinstance(question, str) and question in _CACHE:
        return {"messages": [_CACHE[question]]}
    return state
| # -------------------------------------------------------------------------------------- | |
| # 3. 向量检索(Supabase)作为“相似问题提示” | |
| # -------------------------------------------------------------------------------------- | |
def build_vector_store() -> SupabaseVectorStore:
    """Create the Supabase vector store used for similar-question hints.

    Text is embedded with ``sentence-transformers/all-mpnet-base-v2``; rows
    live in the ``documents`` table and are searched through the
    ``match_documents_langchain`` query. ``SUPABASE_URL`` and
    ``SUPABASE_SERVICE_KEY`` must be set in the environment (a missing
    variable raises ``KeyError``).
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    supabase = create_client(
        os.environ["SUPABASE_URL"],
        os.environ["SUPABASE_SERVICE_KEY"],
    )
    store = SupabaseVectorStore(
        embedding=embedder,
        client=supabase,
        query_name="match_documents_langchain",
        table_name="documents",
    )
    return store


# Built once at import time, so the environment variables above must be set
# before this module is imported.
VECTOR_STORE = build_vector_store()
def similar_q(state: MessagesState):
    """Add the closest stored Q&A from the vector store as a reference hint.

    Args:
        state: Graph state; the text of its last message is the search query.

    Returns:
        ``{"messages": [hint]}`` when the vector store returns a match,
        otherwise ``state`` unchanged.
    """
    query = state["messages"][-1].content
    # similarity_search embeds text; non-text (multimodal) content is skipped.
    if not isinstance(query, str):
        return state
    hits = VECTOR_STORE.similarity_search(query, k=1)
    if not hits:
        return state
    # Sent as a HumanMessage so the conversation still ends on a user turn
    # when the assistant node calls the LLM; an AIMessage here would read as
    # the model's own (partial) reply rather than as reference material.
    hint = HumanMessage(content=f"(参考答案)\n{hits[0].page_content}")
    return {"messages": [hint]}
| # -------------------------------------------------------------------------------------- | |
| # 4. LLM 工厂 | |
| # -------------------------------------------------------------------------------------- | |
def make_llm(provider: str):
    """Create the chat model for the given provider.

    Args:
        provider: One of ``"google"``, ``"groq"`` or ``"huggingface"``.

    Returns:
        A LangChain chat model configured with temperature 0.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-pro", temperature=0)
    if provider == "groq":
        return ChatGroq(model="llama3-70b-8192", temperature=0)
    if provider == "huggingface":
        # HuggingFaceEndpoint takes the target via ``endpoint_url`` (or
        # ``repo_id``); it has no ``url`` parameter.
        # NOTE(review): hosted text-generation backends may reject
        # temperature=0 as not strictly positive; confirm against the endpoint.
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct",
                temperature=0,
            )
        )
    raise ValueError(f"Unknown provider: {provider}")
| # -------------------------------------------------------------------------------------- | |
| # 5. 加载系统提示 | |
| # -------------------------------------------------------------------------------------- | |
# The agent's system instructions, read once at import time from
# system_prompt.txt (resolved against the current working directory).
with open("system_prompt.txt", mode="r", encoding="utf-8") as prompt_file:
    _system_prompt_text = prompt_file.read()
SYSTEM_PROMPT = SystemMessage(content=_system_prompt_text)
| # -------------------------------------------------------------------------------------- | |
| # 6. 构建 LangGraph | |
| # -------------------------------------------------------------------------------------- | |
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    Flow: START -> cache; a cache hit ends the run with the stored answer,
    a miss continues retriever -> assistant, and the assistant loops through
    the tool node until it produces a reply without tool calls -> END.

    Args:
        provider: LLM backend name forwarded to ``make_llm``.

    Returns:
        The compiled LangGraph runnable; invoke it with ``{"messages": [...]}``.
    """
    llm = make_llm(provider).bind_tools(TOOLS)

    def assistant(state: MessagesState):
        # The tool-bound LLM either answers or requests tool calls.
        return {"messages": [llm.invoke(state["messages"])]}

    def route_after_cache(state: MessagesState) -> str:
        # cache_lookup appends an AIMessage only on a hit; otherwise the
        # conversation still ends with the incoming (user) message.
        if isinstance(state["messages"][-1], AIMessage):
            return END
        return "retriever"

    builder = StateGraph(MessagesState)

    # Nodes
    builder.add_node("cache", cache_lookup)
    builder.add_node("retriever", similar_q)
    builder.add_node("assistant", assistant)
    builder.add_node("toolcaller", ToolNode(TOOLS))

    # Edges
    builder.add_edge(START, "cache")
    # Stop on a cache hit instead of running the retriever (which would then
    # search with the cached answer as its query) and calling the LLM again.
    builder.add_conditional_edges(
        "cache",
        route_after_cache,
        {"retriever": "retriever", END: END},
    )
    builder.add_edge("retriever", "assistant")

    # tools_condition returns "tools" when the last AI message requests tool
    # calls and END otherwise; map "tools" onto this graph's "toolcaller" node.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {"tools": "toolcaller", END: END},
    )
    builder.add_edge("toolcaller", "assistant")
    return builder.compile()
| # -------------------------------------------------------------------------------------- | |
| # 7. CLI 用法示例 | |
| # -------------------------------------------------------------------------------------- | |
if __name__ == "__main__":
    import sys

    # The question comes from the command-line arguments, with a demo fallback.
    cli_question = " ".join(sys.argv[1:])
    question = cli_question if cli_question else "Who won the Nobel physics prize in 1921?"

    graph = build_graph("groq")
    result = graph.invoke(
        {"messages": [SYSTEM_PROMPT, HumanMessage(content=question)]}
    )
    for message in result["messages"]:
        message.pretty_print()

    # Store the final message under the raw question text, which is the same
    # string cache_lookup reads from the last (human) message.
    _CACHE[question] = result["messages"][-1]