"""
agent_variant.py
第四种实现方式:
* 自动收集所有 @tool 装饰的函数
* 内置简单内存缓存,避免同问多答
* 将“相似问题检索”与“工具调用”全部拆成独立节点
* 采用工厂模式生成 LLM,支持 google / groq / huggingface 三种 provider
"""
from __future__ import annotations
import os
from typing import Dict, List, Any
from dotenv import load_dotenv
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool, BaseTool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import SupabaseVectorStore
from supabase.client import create_client
load_dotenv()
# --------------------------------------------------------------------------------------
# 1. Tool definitions (arithmetic & retrieval)
# --------------------------------------------------------------------------------------
@tool
def multiply(x: int, y: int) -> int:
"""Return x * y."""
return x * y
@tool
def plus(x: int, y: int) -> int:
"""Return x + y."""
return x + y
@tool
def minus(x: int, y: int) -> int:
"""Return x - y."""
return x - y
@tool
def divide(x: int, y: int) -> float:
"""Return x / y; raise if y==0."""
if y == 0:
raise ValueError("Cannot divide by zero.")
return x / y
@tool
def modulo(x: int, y: int) -> int:
"""Return x % y."""
return x % y
@tool
def wiki(query: str) -> str:
"""Top-2 Wikipedia docs for *query*."""
docs = WikipediaLoader(query=query, load_max_docs=2).load()
return "\n\n---\n\n".join(d.page_content for d in docs)
@tool
def tavily(query: str) -> str:
"""Top-3 Tavily results for *query*."""
    results = TavilySearchResults(max_results=3).invoke(query)  # returns a list of result dicts
    return "\n\n---\n\n".join(r["content"] for r in results)
@tool
def arxiv(query: str) -> str:
"""Top-3 arXiv abstracts for *query*."""
docs = ArxivLoader(query=query, load_max_docs=3).load()
return "\n\n---\n\n".join(d.page_content[:1000] for d in docs)
# Auto-collect every tool: @tool wraps each function in a BaseTool instance,
# so scanning module globals for BaseTool objects picks them all up.
TOOLS = [obj for obj in globals().values() if isinstance(obj, BaseTool)]
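# Sanity check (illustrative only): each @tool-decorated function becomes a
# StructuredTool whose .name is the function name, so all eight tools above
# should be collected:
#   assert {t.name for t in TOOLS} == {"multiply", "plus", "minus", "divide",
#                                      "modulo", "wiki", "tavily", "arxiv"}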
# --------------------------------------------------------------------------------------
# 2. Simple in-memory cache to avoid re-answering the same question
# --------------------------------------------------------------------------------------
_CACHE: Dict[str, AIMessage] = {}
def cache_lookup(state: MessagesState) -> Dict[str, List[Any]]:
    question = state["messages"][-1].content
    if question in _CACHE:  # cache hit: append the stored answer
        return {"messages": [_CACHE[question]]}
    return {}  # miss: no state update, pass through unchanged
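# Behavior sketch (hypothetical values): once a question has been answered,
# a repeat lookup returns the stored AIMessage without re-invoking the LLM:
#   _CACHE["What is 2 + 2?"] = AIMessage(content="4")
#   cache_lookup({"messages": [HumanMessage(content="What is 2 + 2?")]})
#   # -> {"messages": [AIMessage(content="4")]}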
# --------------------------------------------------------------------------------------
# 3. Vector retrieval (Supabase) as a "similar-question hint"
# --------------------------------------------------------------------------------------
def build_vector_store() -> SupabaseVectorStore:
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
client = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])
return SupabaseVectorStore(
client=client,
embedding=embeddings,
table_name="documents",
query_name="match_documents_langchain",
)
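# Note: this assumes the Supabase project already contains a pgvector-backed
# "documents" table and a "match_documents_langchain" SQL function, as
# required by SupabaseVectorStore; neither is created by this script.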
VECTOR_STORE = build_vector_store()
def similar_q(state: MessagesState):
query = state["messages"][-1].content
hits = VECTOR_STORE.similarity_search(query, k=1)
if hits:
        hint = AIMessage(content=f"(Reference answer)\n{hits[0].page_content}")
        return {"messages": [hint]}
    return {}  # no similar question found: pass through unchanged
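# Design note: the hint is injected as an AIMessage, so the assistant sees it
# as prior context rather than as a new user turn; the original question stays
# the last HumanMessage in the conversation.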
# --------------------------------------------------------------------------------------
# 4. LLM factory
# --------------------------------------------------------------------------------------
def make_llm(provider: str):
if provider == "google":
return ChatGoogleGenerativeAI(model="gemini-2.0-pro", temperature=0)
if provider == "groq":
return ChatGroq(model="llama3-70b-8192", temperature=0)
if provider == "huggingface":
return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="/static-proxy?url=https%3A%2F%2Fapi-inference.huggingface.co%2Fmodels%2Fmicrosoft%2FPhi-3-mini-4k-instruct%26quot%3B%2C
                temperature=0,
)
)
raise ValueError(f"Unknown provider: {provider}")
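# Usage sketch (requires the matching key in the environment: GOOGLE_API_KEY,
# GROQ_API_KEY, or HUGGINGFACEHUB_API_TOKEN):
#   llm = make_llm("groq")
#   llm.invoke("Hello!")  # -> AIMessage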
# --------------------------------------------------------------------------------------
# 5. Load the system prompt
# --------------------------------------------------------------------------------------
with open("system_prompt.txt", encoding="utf-8") as fp:
SYSTEM_PROMPT = SystemMessage(content=fp.read())
# --------------------------------------------------------------------------------------
# 6. Build the LangGraph
# --------------------------------------------------------------------------------------
def build_graph(provider: str = "groq"):
llm = make_llm(provider).bind_tools(TOOLS)
def assistant(state: MessagesState):
return {"messages": [llm.invoke(state["messages"])]}
builder = StateGraph(MessagesState)
    # Nodes
builder.add_node("cache", cache_lookup)
builder.add_node("retriever", similar_q)
builder.add_node("assistant", assistant)
builder.add_node("toolcaller", ToolNode(TOOLS))
    # Edges: a cache hit appends the stored AIMessage, so the answer is already
    # final and the run can end early; a miss continues through retrieval.
    def after_cache(state: MessagesState):
        return END if isinstance(state["messages"][-1], AIMessage) else "retriever"
    builder.add_edge(START, "cache")
    builder.add_conditional_edges("cache", after_cache, {"retriever": "retriever", END: END})
    builder.add_edge("retriever", "assistant")
    # Tool call vs. finish: tools_condition returns the string "tools" when the
    # assistant requested a tool call, and END otherwise, so the mapping must
    # key on "tools", not on the node name.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {
            "tools": "toolcaller",
            END: END,
        },
    )
builder.add_edge("toolcaller", "assistant")
return builder.compile()
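# Optional: inspect the compiled topology as a Mermaid diagram via LangGraph's
# built-in drawing helper:
#   print(build_graph().get_graph().draw_mermaid())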
# --------------------------------------------------------------------------------------
# 7. CLI usage example
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
import sys
question = " ".join(sys.argv[1:]) or "Who won the Nobel physics prize in 1921?"
graph = build_graph("groq")
init_msgs = [SYSTEM_PROMPT, HumanMessage(content=question)]
result = graph.invoke({"messages": init_msgs})
for msg in result["messages"]:
msg.pretty_print()
    # Write the final answer into the cache for future runs
    _CACHE[question] = result["messages"][-1]