Spaces:
Sleeping
Sleeping
| """ | |
| agri_gpt_router.py | |
| GPT-driven router + answer refiner for AgriScholarQA. | |
| Flow: | |
| 1. Take a user question. | |
| 2. Use GPT to classify whether it is an *agricultural scholarly question*. | |
| 3. If NOT agri scholarly: | |
| - Do NOT call RAG. | |
| - Reply with a friendly explanation of the system + capabilities. | |
| 4. If agri scholarly: | |
| - Call RAG: rag.ask(question) -> {"answer": raw_answer, "evidence": [...]} | |
| - Send (question, raw_answer, evidence) to GPT to: | |
| * Fix repetition | |
| * Improve structure and formatting (Markdown) | |
| * Ground in evidence with [1], [2] citations. | |
| 5. Additionally, this class exposes: | |
| - start_session(...) | |
| - current_session property | |
| - conversation_manager property | |
| - retrieve_with_context(...) | |
| - validate_answer(...) | |
| so that existing code like `rag.current_session`, `rag.retrieve_with_context`, and | |
| `rag.validate_answer` keep working when you swap AgriCritiqueRAG -> GPTAgriRouter. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from dataclasses import dataclass, asdict | |
| from typing import Any, Dict, List, Optional | |
| # ---- OpenAI client (GPT) ---- | |
| try: | |
| from openai import OpenAI | |
| except ImportError: | |
| OpenAI = None | |
| # ---------------------------------------------------------------------- | |
| # Data structures | |
| # ---------------------------------------------------------------------- | |
@dataclass
class GPTClassification:
    """GPT's routing verdict for a single user query.

    Must be a dataclass: the router constructs it with keyword arguments
    and serializes it with `dataclasses.asdict` in `handle_query`.
    """

    is_agri_scholarly: bool  # True -> query should trigger the RAG pipeline
    intent_type: str  # "agri_scholarly" | "chit_chat" | "generic_qa" | "other"
    confidence: float  # classifier confidence in [0, 1]
    brief_reason: str  # short natural-language justification (1-2 sentences)
@dataclass
class RouterOutput:
    """Structured shape of a routed answer.

    `handle_query` currently returns a plain dict with these same keys;
    this dataclass documents and types that contract. Without `@dataclass`
    the bare annotations would create no constructor or fields at runtime.
    """

    mode: str  # "rag" | "system_chat" | "error"
    answer: str  # final answer to show to the user
    evidence: List[Dict[str, Any]]  # evidence snippets backing the answer
    meta: Dict[str, Any]  # includes classification, raw RAG result, etc.
| # ---------------------------------------------------------------------- | |
| # Main router | |
| # ---------------------------------------------------------------------- | |
class GPTAgriRouter:
    """
    Uses GPT to:
      1) Route between RAG vs. system-chat.
      2) Refine the RAG answer to avoid repetition and improve formatting.

    It also behaves like a thin wrapper around your existing RAG:
      - start_session(...)
      - current_session (property)
      - conversation_manager (property)
      - retrieve_with_context(...)
      - validate_answer(...)
    so your `app.py` code that expects AgriCritiqueRAG continues to work.
    """

    def __init__(
        self,
        rag_system: Any,
        gpt_model_classify: str = "gpt-4.1-mini",
        gpt_model_refine: Optional[str] = None,
        openai_api_key_env: str = "OPENAI_API_KEY",
    ):
        """Wrap an existing RAG system with a GPT router/refiner.

        Args:
            rag_system: object exposing at least `ask(question)`; other
                methods (start_session, validate_answer, ...) are optional.
            gpt_model_classify: model used for intent classification.
            gpt_model_refine: model used for answer refinement; defaults to
                the classification model when not given.
            openai_api_key_env: env var holding the OpenAI API key.

        Raises:
            ImportError: if the `openai` package is not installed.
            ValueError: if the API key env var is unset/empty.
        """
        self.rag = rag_system
        self.gpt_model_classify = gpt_model_classify
        # Reuse the classifier model for refinement unless one is provided.
        self.gpt_model_refine = gpt_model_refine or gpt_model_classify
        if OpenAI is None:
            raise ImportError(
                "openai library not installed. Run `pip install openai` in your environment/Space."
            )
        api_key = os.getenv(openai_api_key_env)
        if not api_key:
            raise ValueError(
                f"{openai_api_key_env} is not set. Please add it as a secret (e.g., in your Space)."
            )
        self.client = OpenAI(api_key=api_key)

    # ========= 1. Classification (GPT) =========
    def _classify_with_gpt(self, question: str) -> GPTClassification:
        """Ask GPT whether `question` is an agricultural scholarly query.

        Uses JSON-mode output at temperature 0; on unparseable JSON falls
        back to a non-agri classification so we never waste a RAG call.
        """
        system_prompt = (
            "You are a classifier for an agricultural research assistant called AgriScholarQA.\n\n"
            "Your job: given a single user query, decide whether it is an "
            "**agricultural scholarly question** that should trigger a retrieval-augmented "
            "pipeline over agricultural research papers.\n\n"
            "Definitions:\n"
            "- Agricultural scholarly question: asks about crops, soils, climate impacts, "
            "  agronomy, plant physiology, agricultural experiments, yields, pests, diseases, "
            "  fertilizers, irrigation, crop models, etc., in a technically informed way.\n"
            "- Chit-chat / meta: greetings, what is this system, who are you, etc.\n"
            "- Generic QA: everyday knowledge or non-agricultural topics (e.g., movies, math, generic ML).\n"
            "- Other: anything else not clearly fitting above.\n\n"
            "Return a strict JSON object with fields:\n"
            "- is_agri_scholarly: boolean\n"
            "- intent_type: one of \"agri_scholarly\", \"chit_chat\", \"generic_qa\", \"other\"\n"
            "- confidence: float between 0 and 1\n"
            "- brief_reason: short natural language reason (1–2 sentences)\n\n"
            "Do not add any extra keys. Do not write any explanations outside the JSON."
        )
        user_prompt = f"User query:\n\"\"\"{question}\"\"\""
        resp = self.client.chat.completions.create(
            model=self.gpt_model_classify,
            temperature=0,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        raw = resp.choices[0].message.content.strip()
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as e:
            # Safest fallback: treat as non-agri so we don't waste RAG calls
            return GPTClassification(
                is_agri_scholarly=False,
                intent_type="other",
                confidence=0.0,
                brief_reason=f"Failed to parse GPT JSON: {e} | raw={raw[:200]}",
            )
        # Coerce defensively: the model may omit keys or use wrong types.
        return GPTClassification(
            is_agri_scholarly=bool(data.get("is_agri_scholarly", False)),
            intent_type=str(data.get("intent_type", "other")),
            confidence=float(data.get("confidence", 0.0)),
            brief_reason=str(data.get("brief_reason", "")),
        )

    # ========= 2. Refine RAG answer (GPT) =========
    def _refine_answer_with_gpt(
        self,
        question: str,
        raw_answer: str,
        evidence: List[Dict[str, Any]],
    ) -> str:
        """Rewrite the draft RAG answer into clean, cited Markdown.

        Only the top 5 evidence chunks are sent, each whitespace-collapsed
        and truncated to 800 chars to bound prompt size.
        """
        # Build evidence summary (top 5 chunks)
        ev_blocks = []
        for i, ev in enumerate(evidence[:5], 1):
            title = ev.get("paper_title") or ev.get("paper_id") or f"Doc {ev.get('idx', i)}"
            snippet = ev.get("text") or ev.get("text_preview") or ""
            snippet = " ".join(snippet.split())
            snippet = snippet[:800]
            ev_blocks.append(f"[{i}] {title}\n{snippet}\n")
        evidence_text = "\n\n".join(ev_blocks) if ev_blocks else "(no evidence available)"
        system_prompt = (
            "You are an expert agricultural research assistant.\n\n"
            "You are given:\n"
            "1) The user's question.\n"
            "2) A draft answer from another model (may be repetitive or imperfect).\n"
            "3) Evidence snippets from research papers.\n\n"
            "Your task:\n"
            "- Produce a SINGLE, clean, well-structured answer in Markdown.\n"
            "- Do NOT repeat the same idea or sentence multiple times.\n"
            "- Use headings or bullet points if helpful, but keep it concise.\n"
            "- Ground the answer in the evidence; when you use a fact from a snippet, cite it "
            "  as [1], [2], etc.\n"
            # Typo fix: was "separte section metaining"
            "- Give the evidence as a separate section maintaining the citation order [1], [2], etc.\n"
            "- If the draft answer is clearly unsupported or contradicted by the evidence, correct it.\n"
            "- If evidence is insufficient, clearly state limitations instead of hallucinating.\n"
        )
        user_prompt = (
            f"QUESTION:\n{question}\n\n"
            f"DRAFT ANSWER:\n{raw_answer}\n\n"
            f"EVIDENCE SNIPPETS:\n{evidence_text}\n\n"
            "Now rewrite the answer following the instructions."
        )
        resp = self.client.chat.completions.create(
            model=self.gpt_model_refine,
            temperature=0.3,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return resp.choices[0].message.content.strip()

    # ========= 3. System chat for non-agri queries =========
    def _system_chat_answer(self, question: str, cls: GPTClassification) -> str:
        """Build the friendly capabilities reply for non-agri queries.

        Only `cls.intent_type` is read; no GPT/RAG calls are made here.
        """
        intro = (
            "Hi! 👋 I’m **AgriScholarQA**, an agricultural scholarly assistant.\n\n"
            "I’m designed specifically to answer **research-oriented questions about agriculture** "
            "using a retrieval-augmented pipeline over scientific papers."
        )
        capabilities = (
            "\n\n**Here’s what I can do:**\n"
            "- 📚 Answer questions about **crop production, soil, climate impacts, pests, diseases**, etc.\n"
            "- 🔍 Retrieve and show **evidence from agricultural research papers**.\n"
            "- 🧪 Help you reason about **field experiments, treatments, and agronomic practices**.\n"
            "- 🚨 Detect potential **hallucinations or weakly supported claims**.\n"
        )
        if cls.intent_type in {"chit_chat", "generic_qa"}:
            meta = (
                f"\nYour query looks like **{cls.intent_type.replace('_', ' ')}**, "
                "not a detailed agricultural research question, so I didn’t trigger "
                "the heavy evidence-retrieval pipeline this time.\n"
            )
        else:
            meta = (
                "\nYour query doesn’t look like an agricultural scholarly question, "
                "so I’m staying in simple chat mode instead of hitting the research index.\n"
            )
        nudge = (
            "\nIf you’d like to use my full power, try asking something like:\n"
            "- *“How does drought stress during flowering affect rice yield in irrigated systems?”*\n"
            "- *“What are sustainable pest management strategies for maize in tropical regions?”*\n"
            "- *“How does nitrogen fertilizer rate influence wheat grain protein under heat stress?”*\n"
        )
        return intro + capabilities + meta + nudge

    # ========= 4. Public entry: handle_query =========
    def handle_query(self, question: str) -> Dict[str, Any]:
        """Route a user question and return a response dict.

        Returns a dict with keys: mode ("rag" | "system_chat" | "error"),
        answer (Markdown string), evidence (list), meta (dict).
        """
        q = (question or "").strip()
        if not q:
            return {
                "mode": "system_chat",
                "answer": "Please enter a question. I specialize in agricultural research questions.",
                "evidence": [],
                "meta": {"classification": None},
            }
        # 1. classify
        try:
            cls = self._classify_with_gpt(q)
        except Exception as e:
            return {
                "mode": "error",
                "answer": f"⚠️ Error while classifying your question with GPT: `{e}`",
                "evidence": [],
                "meta": {"classification": None},
            }
        # 2. non-agri (or low-confidence) -> system chat, skip RAG entirely
        if (not cls.is_agri_scholarly) or cls.confidence < 0.5:
            answer = self._system_chat_answer(q, cls)
            return {
                "mode": "system_chat",
                "answer": answer,
                "evidence": [],
                "meta": {"classification": asdict(cls)},
            }
        # 3. agri -> RAG + refine
        try:
            rag_result = self.rag.ask(q)
        except Exception as e:
            return {
                "mode": "error",
                "answer": (
                    "Your question looks like an **agricultural scholarly query**, "
                    "but I encountered an error while running the retrieval pipeline:\n\n"
                    f"`{e}`"
                ),
                "evidence": [],
                "meta": {"classification": asdict(cls)},
            }
        rag_is_dict = isinstance(rag_result, dict)
        raw_answer = rag_result.get("answer", "") if rag_is_dict else str(rag_result)
        evidence = rag_result.get("evidence", []) if rag_is_dict else []
        # 4. refine with GPT; on failure fall back to the raw RAG answer
        #    (previously this try/except was commented out, leaving
        #    `refined_answer` referencing an undefined `e` -> NameError).
        try:
            refined_answer = self._refine_answer_with_gpt(q, raw_answer, evidence)
        except Exception as e:
            refined_answer = (
                f"⚠️ I had trouble refining the answer with GPT (`{e}`). "
                "Showing the original RAG answer below:\n\n"
                + raw_answer
            )
        return {
            "mode": "rag",
            "answer": refined_answer,
            "evidence": evidence,
            "meta": {
                "classification": asdict(cls),
                "raw_rag_result": rag_result,
            },
        }

    # ========= 5. Adapter methods/properties so old code still works =========
    # These fix: 'GPTAgriRouter' object has no attribute 'current_session'
    @property
    def current_session(self):
        """Underlying RAG's current session, or None.

        Must be a property: app.py does `if not rag.current_session:` —
        a plain method object would always be truthy.
        """
        return getattr(self.rag, "current_session", None)

    def start_session(self, metadata=None):
        """
        Delegate to underlying RAG so existing session logic works.
        Returns None when the RAG exposes no `start_session`.
        """
        if hasattr(self.rag, "start_session"):
            return self.rag.start_session(metadata)
        return None

    @property
    def conversation_manager(self):
        """
        For code using: rag.conversation_manager.get_session_history(...)
        Must be a property so attribute-style access works; returns None
        when the underlying RAG has no conversation manager.
        """
        return getattr(self.rag, "conversation_manager", None)

    def retrieve_with_context(self, question: str, history: List[Dict[str, Any]], top_k: int = 5):
        """
        Pass-through for validation mode, etc.
        Returns [] when the underlying RAG has no such method.
        """
        if hasattr(self.rag, "retrieve_with_context"):
            return self.rag.retrieve_with_context(question, history, top_k=top_k)
        return []

    def validate_answer(self, question: str, proposed_answer: str, evidence: List[Dict[str, Any]]) -> str:
        """
        Pass-through so your existing validate_answer_button can still call rag.validate_answer(...)
        """
        if hasattr(self.rag, "validate_answer"):
            return self.rag.validate_answer(question, proposed_answer, evidence)
        # Fallback message if underlying rag has no validate method
        return "Validation not available: underlying RAG has no 'validate_answer' method."