"""
agent_variant.py
第四种实现方式:  
* 自动收集所有 @tool 装饰的函数  
* 内置简单内存缓存,避免同问多答  
* 将“相似问题检索”与“工具调用”全部拆成独立节点  
* 采用工厂模式生成 LLM,支持 google / groq / huggingface 三种 provider
"""

from __future__ import annotations

import os
from typing import Dict, List, Any

from dotenv import load_dotenv
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import BaseTool, tool

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings

from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import SupabaseVectorStore
from supabase.client import create_client

load_dotenv()
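
# Depending on the chosen provider and tools, the following environment
# variables are expected (names as read by the respective client libraries):
#   GROQ_API_KEY / GOOGLE_API_KEY / HUGGINGFACEHUB_API_TOKEN,
#   TAVILY_API_KEY, SUPABASE_URL, SUPABASE_SERVICE_KEY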

# --------------------------------------------------------------------------------------
# 1. Tool definitions (arithmetic & retrieval)
# --------------------------------------------------------------------------------------
@tool
def multiply(x: int, y: int) -> int:
    """Return x * y."""
    return x * y

@tool
def plus(x: int, y: int) -> int:
    """Return x + y."""
    return x + y

@tool
def minus(x: int, y: int) -> int:
    """Return x - y."""
    return x - y

@tool
def divide(x: int, y: int) -> float:
    """Return x / y; raise if y==0."""
    if y == 0:
        raise ValueError("Cannot divide by zero.")
    return x / y

@tool
def modulo(x: int, y: int) -> int:
    """Return x % y."""
    return x % y

@tool
def wiki(query: str) -> str:
    """Top-2 Wikipedia docs for *query*."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(d.page_content for d in docs)

@tool
def tavily(query: str) -> str:
    """Top-3 Tavily results for *query*."""
    # TavilySearchResults yields dicts with "url"/"content" keys, and .invoke
    # takes the query string as its positional input.
    results = TavilySearchResults(max_results=3).invoke(query)
    return "\n\n---\n\n".join(r["content"] for r in results)

@tool
def arxiv(query: str) -> str:
    """Top-3 arXiv abstracts for *query*."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(d.page_content[:1000] for d in docs)

# Automatically collect every tool: the @tool decorator wraps each function in
# a BaseTool instance, so an isinstance check over globals() finds them all.
TOOLS = [obj for obj in globals().values() if isinstance(obj, BaseTool)]
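
# Each collected object is a StructuredTool and can be invoked directly with a
# dict of arguments as a quick sanity check, e.g. (hypothetical):
#   multiply.invoke({"x": 6, "y": 7})  # -> 42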

# --------------------------------------------------------------------------------------
# 2. Simple in-memory cache to avoid repeated computation
# --------------------------------------------------------------------------------------
_CACHE: Dict[str, AIMessage] = {}

def cache_lookup(state: MessagesState) -> Dict[str, List[Any]]:
    question = state["messages"][-1].content
    if question in _CACHE:            # cache hit: surface the stored answer
        return {"messages": [_CACHE[question]]}
    return {}                         # cache miss: no state update needed
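
# Illustration of the contract (hypothetical values):
#   _CACHE["2 + 2?"] = AIMessage(content="4")
#   cache_lookup({"messages": [HumanMessage(content="2 + 2?")]})
#   # -> {"messages": [AIMessage(content="4")]}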

# --------------------------------------------------------------------------------------
# 3. Vector retrieval (Supabase) used as a "similar question hint"
# --------------------------------------------------------------------------------------
def build_vector_store() -> SupabaseVectorStore:
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    client = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])
    return SupabaseVectorStore(
        client=client,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )

VECTOR_STORE = build_vector_store()

def similar_q(state: MessagesState):
    query = state["messages"][-1].content
    hits = VECTOR_STORE.similarity_search(query, k=1)
    if hits:
        hint = AIMessage(content=f"(Reference answer)\n{hits[0].page_content}")
        return {"messages": [hint]}
    return {}  # no similar question found: no state update needed
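
# The store must already hold documents for the hint to fire; a one-off setup
# could use the standard VectorStore API (hypothetical content, assuming the
# `documents` table and `match_documents_langchain` RPC exist in Supabase):
#   VECTOR_STORE.add_texts(["Q: ...\nA: ..."])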

# --------------------------------------------------------------------------------------
# 4. LLM factory
# --------------------------------------------------------------------------------------
def make_llm(provider: str):
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-pro", temperature=0)
    if provider == "groq":
        return ChatGroq(model="llama3-70b-8192", temperature=0)
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="/static-proxy?url=https%3A%2F%2Fapi-inference.huggingface.co%2Fmodels%2Fmicrosoft%2FPhi-3-mini-4k-instruct%26quot%3B%3C%2Fspan%3E%2C
                temperature=0
            )
        )
    raise ValueError(f"Unknown provider: {provider}")
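
# e.g. llm = make_llm("groq")  # requires GROQ_API_KEY in the environment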

# --------------------------------------------------------------------------------------
# 5. Load the system prompt
# --------------------------------------------------------------------------------------
with open("system_prompt.txt", encoding="utf-8") as fp:
    SYSTEM_PROMPT = SystemMessage(content=fp.read())

# --------------------------------------------------------------------------------------
# 6. Build the LangGraph
# --------------------------------------------------------------------------------------
def build_graph(provider: str = "groq"):
    llm = make_llm(provider).bind_tools(TOOLS)

    def assistant(state: MessagesState):
        return {"messages": [llm.invoke(state["messages"])]}

    builder = StateGraph(MessagesState)

    # Nodes
    builder.add_node("cache", cache_lookup)
    builder.add_node("retriever", similar_q)
    builder.add_node("assistant", assistant)
    builder.add_node("toolcaller", ToolNode(TOOLS))

    # Edges
    builder.add_edge(START, "cache")
    builder.add_edge("cache", "retriever")
    builder.add_edge("retriever", "assistant")

    # Decide whether to call a tool or finish
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {
            "tools": "toolcaller",  # tools_condition emits the literal string "tools"
            END: END,
        },
    )
    builder.add_edge("toolcaller", "assistant")

    return builder.compile()
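
# The compiled topology can be inspected if needed (optional debugging aid):
#   print(build_graph("groq").get_graph().draw_mermaid())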

# --------------------------------------------------------------------------------------
# 7. CLI usage example
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    import sys
    question = " ".join(sys.argv[1:]) or "Who won the Nobel physics prize in 1921?"
    graph = build_graph("groq")
    init_msgs = [SYSTEM_PROMPT, HumanMessage(content=question)]
    result = graph.invoke({"messages": init_msgs})
    for msg in result["messages"]:
        msg.pretty_print()
    # Store the final answer so an identical future question hits the cache
    _CACHE[question] = result["messages"][-1]