# AgriScholarQA / agri_gpt_router.py
# (Hugging Face Hub page residue below — kept as comments so the file parses)
# sayande's picture
# Update agri_gpt_router.py
# 329fdb9 verified
"""
agri_gpt_router.py
GPT-driven router + answer refiner for AgriScholarQA.
Flow:
1. Take a user question.
2. Use GPT to classify whether it is an *agricultural scholarly question*.
3. If NOT agri scholarly:
- Do NOT call RAG.
- Reply with a friendly explanation of the system + capabilities.
4. If agri scholarly:
- Call RAG: rag.ask(question) -> {"answer": raw_answer, "evidence": [...]}
- Send (question, raw_answer, evidence) to GPT to:
* Fix repetition
* Improve structure and formatting (Markdown)
* Ground in evidence with [1], [2] citations.
5. Additionally, this class exposes:
- start_session(...)
- current_session property
- conversation_manager property
- retrieve_with_context(...)
- validate_answer(...)
so that existing code like `rag.current_session`, `rag.retrieve_with_context`, and
`rag.validate_answer` keep working when you swap AgriCritiqueRAG -> GPTAgriRouter.
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional
# ---- OpenAI client (GPT) ----
try:
from openai import OpenAI
except ImportError:
OpenAI = None
# ----------------------------------------------------------------------
# Data structures
# ----------------------------------------------------------------------
@dataclass
class GPTClassification:
    """Outcome of the GPT intent-classification step for a single user query."""
    # True when the query should trigger the retrieval (RAG) pipeline.
    is_agri_scholarly: bool
    intent_type: str  # "agri_scholarly" | "chit_chat" | "generic_qa" | "other"
    # Classifier self-reported confidence in [0, 1]; handle_query gates on >= 0.5.
    confidence: float
    # 1-2 sentence natural-language justification from the classifier.
    brief_reason: str
@dataclass
class RouterOutput:
    """Structured router result (same shape as the dicts handle_query returns)."""
    mode: str  # "rag" | "system_chat" | "error"
    answer: str  # final answer to show to user (Markdown)
    # Evidence snippets passed through from the RAG result (empty for non-RAG modes).
    evidence: List[Dict[str, Any]]
    meta: Dict[str, Any]  # includes classification, raw RAG result, etc.
# ----------------------------------------------------------------------
# Main router
# ----------------------------------------------------------------------
class GPTAgriRouter:
    """
    GPT-driven router + answer refiner for AgriScholarQA.

    Uses GPT to:
      1) Route between RAG vs. system-chat.
      2) Refine the raw RAG answer to avoid repetition and improve formatting.

    It also behaves like a thin wrapper around the underlying RAG system:
      - start_session(...)
      - current_session
      - conversation_manager
      - retrieve_with_context(...)
      - validate_answer(...)
    so `app.py` code that expects AgriCritiqueRAG continues to work when you
    swap AgriCritiqueRAG -> GPTAgriRouter.
    """

    def __init__(
        self,
        rag_system: Any,
        gpt_model_classify: str = "gpt-4.1-mini",
        gpt_model_refine: Optional[str] = None,
        openai_api_key_env: str = "OPENAI_API_KEY",
    ):
        """
        Args:
            rag_system: underlying RAG object; must expose `ask(question)` and
                may expose the session/validation methods delegated to below.
            gpt_model_classify: model name used for intent classification.
            gpt_model_refine: model name used for answer refinement; defaults
                to the classification model when omitted.
            openai_api_key_env: name of the env var holding the OpenAI API key.

        Raises:
            ImportError: if the `openai` package is not installed.
            ValueError: if the API-key environment variable is not set.
        """
        self.rag = rag_system
        self.gpt_model_classify = gpt_model_classify
        self.gpt_model_refine = gpt_model_refine or gpt_model_classify
        if OpenAI is None:
            raise ImportError(
                "openai library not installed. Run `pip install openai` in your environment/Space."
            )
        api_key = os.getenv(openai_api_key_env)
        if not api_key:
            raise ValueError(
                f"{openai_api_key_env} is not set. Please add it as a secret (e.g., in your Space)."
            )
        self.client = OpenAI(api_key=api_key)

    # ========= 1. Classification (GPT) =========
    def _classify_with_gpt(self, question: str) -> GPTClassification:
        """Ask GPT whether `question` is an agricultural scholarly query.

        Returns a GPTClassification; on unparseable GPT output, falls back to a
        low-confidence "other" so the expensive RAG pipeline is not triggered.
        """
        system_prompt = (
            "You are a classifier for an agricultural research assistant called AgriScholarQA.\n\n"
            "Your job: given a single user query, decide whether it is an "
            "**agricultural scholarly question** that should trigger a retrieval-augmented "
            "pipeline over agricultural research papers.\n\n"
            "Definitions:\n"
            "- Agricultural scholarly question: asks about crops, soils, climate impacts, "
            " agronomy, plant physiology, agricultural experiments, yields, pests, diseases, "
            " fertilizers, irrigation, crop models, etc., in a technically informed way.\n"
            "- Chit-chat / meta: greetings, what is this system, who are you, etc.\n"
            "- Generic QA: everyday knowledge or non-agricultural topics (e.g., movies, math, generic ML).\n"
            "- Other: anything else not clearly fitting above.\n\n"
            "Return a strict JSON object with fields:\n"
            "- is_agri_scholarly: boolean\n"
            "- intent_type: one of \"agri_scholarly\", \"chit_chat\", \"generic_qa\", \"other\"\n"
            "- confidence: float between 0 and 1\n"
            "- brief_reason: short natural language reason (1–2 sentences)\n\n"
            "Do not add any extra keys. Do not write any explanations outside the JSON."
        )
        user_prompt = f"User query:\n\"\"\"{question}\"\"\""
        resp = self.client.chat.completions.create(
            model=self.gpt_model_classify,
            temperature=0,  # deterministic routing
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        # Guard against a None content payload before stripping.
        raw = (resp.choices[0].message.content or "").strip()
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as e:
            # Safest fallback: treat as non-agri so we don't waste RAG calls
            return GPTClassification(
                is_agri_scholarly=False,
                intent_type="other",
                confidence=0.0,
                brief_reason=f"Failed to parse GPT JSON: {e} | raw={raw[:200]}",
            )
        return GPTClassification(
            is_agri_scholarly=bool(data.get("is_agri_scholarly", False)),
            intent_type=str(data.get("intent_type", "other")),
            confidence=float(data.get("confidence", 0.0)),
            brief_reason=str(data.get("brief_reason", "")),
        )

    # ========= 2. Refine RAG answer (GPT) =========
    def _refine_answer_with_gpt(
        self,
        question: str,
        raw_answer: str,
        evidence: List[Dict[str, Any]],
    ) -> str:
        """Rewrite the draft RAG answer with GPT: deduplicate, format as
        Markdown, and ground it in the evidence snippets with [n] citations.
        """
        # Build evidence summary (top 5 chunks)
        ev_blocks = []
        for i, ev in enumerate(evidence[:5], 1):
            title = ev.get("paper_title") or ev.get("paper_id") or f"Doc {ev.get('idx', i)}"
            snippet = ev.get("text") or ev.get("text_preview") or ""
            snippet = " ".join(snippet.split())  # collapse whitespace/newlines
            snippet = snippet[:800]  # cap snippet length to bound prompt size
            ev_blocks.append(f"[{i}] {title}\n{snippet}\n")
        evidence_text = "\n\n".join(ev_blocks) if ev_blocks else "(no evidence available)"
        system_prompt = (
            "You are an expert agricultural research assistant.\n\n"
            "You are given:\n"
            "1) The user's question.\n"
            "2) A draft answer from another model (may be repetitive or imperfect).\n"
            "3) Evidence snippets from research papers.\n\n"
            "Your task:\n"
            "- Produce a SINGLE, clean, well-structured answer in Markdown.\n"
            "- Do NOT repeat the same idea or sentence multiple times.\n"
            "- Use headings or bullet points if helpful, but keep it concise.\n"
            "- Ground the answer in the evidence; when you use a fact from a snippet, cite it "
            " as [1], [2], etc.\n"
            # Fixed garbled instruction ("separte section metaining").
            "- Give the evidence as a separate section maintaining the citation order [1], [2], etc.\n"
            "- If the draft answer is clearly unsupported or contradicted by the evidence, correct it.\n"
            "- If evidence is insufficient, clearly state limitations instead of hallucinating.\n"
        )
        user_prompt = (
            f"QUESTION:\n{question}\n\n"
            f"DRAFT ANSWER:\n{raw_answer}\n\n"
            f"EVIDENCE SNIPPETS:\n{evidence_text}\n\n"
            "Now rewrite the answer following the instructions."
        )
        resp = self.client.chat.completions.create(
            model=self.gpt_model_refine,
            temperature=0.3,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return (resp.choices[0].message.content or "").strip()

    # ========= 3. System chat for non-agri queries =========
    def _system_chat_answer(self, question: str, cls: GPTClassification) -> str:
        """Build the canned capabilities reply for non-agri-scholarly queries.

        No GPT/RAG call is made here; only `cls.intent_type` is consulted.
        """
        intro = (
            "Hi! 👋 I’m **AgriScholarQA**, an agricultural scholarly assistant.\n\n"
            "I’m designed specifically to answer **research-oriented questions about agriculture** "
            "using a retrieval-augmented pipeline over scientific papers."
        )
        capabilities = (
            "\n\n**Here’s what I can do:**\n"
            "- 📚 Answer questions about **crop production, soil, climate impacts, pests, diseases**, etc.\n"
            "- 🔍 Retrieve and show **evidence from agricultural research papers**.\n"
            "- 🧪 Help you reason about **field experiments, treatments, and agronomic practices**.\n"
            "- 🚨 Detect potential **hallucinations or weakly supported claims**.\n"
        )
        if cls.intent_type in {"chit_chat", "generic_qa"}:
            meta = (
                f"\nYour query looks like **{cls.intent_type.replace('_', ' ')}**, "
                "not a detailed agricultural research question, so I didn’t trigger "
                "the heavy evidence-retrieval pipeline this time.\n"
            )
        else:
            meta = (
                "\nYour query doesn’t look like an agricultural scholarly question, "
                "so I’m staying in simple chat mode instead of hitting the research index.\n"
            )
        nudge = (
            "\nIf you’d like to use my full power, try asking something like:\n"
            "- *“How does drought stress during flowering affect rice yield in irrigated systems?”*\n"
            "- *“What are sustainable pest management strategies for maize in tropical regions?”*\n"
            "- *“How does nitrogen fertilizer rate influence wheat grain protein under heat stress?”*\n"
        )
        return intro + capabilities + meta + nudge

    # ========= 4. Public entry: handle_query =========
    def handle_query(self, question: str) -> Dict[str, Any]:
        """Route a user question and return a response dict.

        Returns a dict with keys "mode" ("rag" | "system_chat" | "error"),
        "answer" (Markdown string), "evidence" (list), and "meta".
        """
        q = (question or "").strip()
        if not q:
            return {
                "mode": "system_chat",
                "answer": "Please enter a question. I specialize in agricultural research questions.",
                "evidence": [],
                "meta": {"classification": None},
            }
        # 1. classify
        try:
            cls = self._classify_with_gpt(q)
        except Exception as e:
            return {
                "mode": "error",
                "answer": f"⚠️ Error while classifying your question with GPT: `{e}`",
                "evidence": [],
                "meta": {"classification": None},
            }
        # 2. non-agri (or low-confidence agri) -> system chat, skip RAG
        if (not cls.is_agri_scholarly) or cls.confidence < 0.5:
            answer = self._system_chat_answer(q, cls)
            return {
                "mode": "system_chat",
                "answer": answer,
                "evidence": [],
                "meta": {"classification": asdict(cls)},
            }
        # 3. agri -> RAG + refine
        try:
            rag_result = self.rag.ask(q)
        except Exception as e:
            return {
                "mode": "error",
                "answer": (
                    "Your question looks like an **agricultural scholarly query**, "
                    "but I encountered an error while running the retrieval pipeline:\n\n"
                    f"`{e}`"
                ),
                "evidence": [],
                "meta": {"classification": asdict(cls)},
            }
        rag_result_is_dict = isinstance(rag_result, dict)
        raw_answer = rag_result.get("answer", "") if rag_result_is_dict else str(rag_result)
        evidence = rag_result.get("evidence", []) if rag_result_is_dict else []
        # BUG FIX: the refinement call was commented out, leaving `refined_answer`
        # reference an undefined `e` (NameError on every successful RAG path).
        # Refinement is best-effort: on failure, fall back to the raw RAG answer.
        try:
            refined_answer = self._refine_answer_with_gpt(q, raw_answer, evidence)
        except Exception as e:
            refined_answer = (
                f"⚠️ I had trouble refining the answer with GPT (`{e}`). "
                "Showing the original RAG answer below:\n\n"
                + raw_answer
            )
        return {
            "mode": "rag",
            "answer": refined_answer,
            "evidence": evidence,
            "meta": {
                "classification": asdict(cls),
                "raw_rag_result": rag_result,
            },
        }

    # ========= 5. Adapter methods/properties so old code still works =========
    # These fix: 'GPTAgriRouter' object has no attribute 'current_session'
    # For app.py calls: if not rag.current_session: rag.start_session()
    @property
    def current_session(self):
        """Current session of the underlying RAG, or None if it has none."""
        return getattr(self.rag, "current_session", None)

    def start_session(self, metadata=None):
        """
        Delegate to underlying RAG so existing session logic works.
        Returns None when the RAG has no `start_session`.
        """
        if hasattr(self.rag, "start_session"):
            return self.rag.start_session(metadata)
        return None

    @property
    def conversation_manager(self):
        """
        For code using: rag.conversation_manager.get_session_history(...)
        """
        return getattr(self.rag, "conversation_manager", None)

    def retrieve_with_context(self, question: str, history: List[Dict[str, Any]], top_k: int = 5):
        """
        Pass-through for validation mode, etc. Returns [] when unsupported.
        """
        if hasattr(self.rag, "retrieve_with_context"):
            return self.rag.retrieve_with_context(question, history, top_k=top_k)
        return []

    def validate_answer(self, question: str, proposed_answer: str, evidence: List[Dict[str, Any]]) -> str:
        """
        Pass-through so your existing validate_answer_button can still call rag.validate_answer(...)
        """
        if hasattr(self.rag, "validate_answer"):
            return self.rag.validate_answer(question, proposed_answer, evidence)
        # Fallback message if underlying rag has no validate method
        return "Validation not available: underlying RAG has no 'validate_answer' method."