imurra commited on
Commit
a9d88a0
·
verified ·
1 Parent(s): 2c2aa6b

unzip database

Browse files
Files changed (1) hide show
  1. app.py +12 -75
app.py CHANGED
@@ -1,5 +1,6 @@
1
- # app.py - Hugging Face Spaces version
2
  import os
 
3
  from fastapi import FastAPI, HTTPException
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
@@ -9,11 +10,20 @@ import gradio as gr
9
 
10
  # Database path
11
  DB_PATH = "./medqa_db"
 
 
 
 
 
 
 
 
12
 
13
  # Initialize
14
  print(f"Loading database from: {DB_PATH}")
15
  client = chromadb.PersistentClient(path=DB_PATH)
16
  collection = client.get_collection("medqa")
 
17
  print(f"Loading MedCPT model...")
18
  model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
19
  print("Initialization complete!")
@@ -34,77 +44,4 @@ class SearchRequest(BaseModel):
34
  num_results: int = 3
35
 
36
  class SearchResponse(BaseModel):
37
- results: list[dict]
38
-
39
- @app.get("/")
40
- async def root():
41
- return {
42
- "message": "MedQA Search API - Hugging Face Version",
43
- "status": "running",
44
- "collection_count": collection.count()
45
- }
46
-
47
- @app.post("/search_medqa", response_model=SearchResponse)
48
- async def search_medqa(request: SearchRequest):
49
- """Search MedQA database for similar USMLE questions"""
50
- try:
51
- embedding = model.encode(request.query).tolist()
52
- results = collection.query(
53
- query_embeddings=[embedding],
54
- n_results=request.num_results
55
- )
56
-
57
- formatted_results = []
58
- for i in range(len(results['documents'][0])):
59
- formatted_results.append({
60
- "example_number": i + 1,
61
- "question": results['documents'][0][i],
62
- "answer": results['metadatas'][0][i].get('answer', 'N/A'),
63
- "distance": results['distances'][0][i] if 'distances' in results else None
64
- })
65
-
66
- return SearchResponse(results=formatted_results)
67
- except Exception as e:
68
- raise HTTPException(status_code=500, detail=str(e))
69
-
70
- # Gradio interface (optional - gives you a web UI)
71
- def search_interface(query: str, num_results: int = 3):
72
- """Simple web interface for testing"""
73
- try:
74
- embedding = model.encode(query).tolist()
75
- results = collection.query(
76
- query_embeddings=[embedding],
77
- n_results=num_results
78
- )
79
-
80
- output = ""
81
- for i in range(len(results['documents'][0])):
82
- output += f"\n{'='*60}\n"
83
- output += f"Example {i+1}\n"
84
- output += f"{'='*60}\n"
85
- output += results['documents'][0][i] + "\n"
86
- output += f"\nAnswer: {results['metadatas'][0][i].get('answer', 'N/A')}\n"
87
- output += f"Similarity: {1 - results['distances'][0][i]:.3f}\n"
88
-
89
- return output
90
- except Exception as e:
91
- return f"Error: {str(e)}"
92
-
93
- # Create Gradio interface
94
- demo = gr.Interface(
95
- fn=search_interface,
96
- inputs=[
97
- gr.Textbox(label="Medical Topic or Clinical Scenario", placeholder="e.g., hyponatremia"),
98
- gr.Slider(1, 5, value=3, step=1, label="Number of Examples")
99
- ],
100
- outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
101
- title="MedQA Search - USMLE Question Database",
102
- description="Search for similar USMLE Step 1 questions using semantic similarity"
103
- )
104
-
105
- # Mount Gradio app and FastAPI
106
- app = gr.mount_gradio_app(app, demo, path="/")
107
-
108
- if __name__ == "__main__":
109
- import uvicorn
110
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ # app.py - Hugging Face Spaces version with auto-extract
2
  import os
3
+ import zipfile
4
  from fastapi import FastAPI, HTTPException
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
 
10
 
11
  # Database path
12
  DB_PATH = "./medqa_db"
13
+ ZIP_PATH = "./medqa_db.zip"
14
+
15
+ # Extract database if needed
16
+ if not os.path.exists(DB_PATH) and os.path.exists(ZIP_PATH):
17
+ print("Extracting database from zip file...")
18
+ with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
19
+ zip_ref.extractall(".")
20
+ print("Database extracted successfully!")
21
 
22
  # Initialize
23
  print(f"Loading database from: {DB_PATH}")
24
  client = chromadb.PersistentClient(path=DB_PATH)
25
  collection = client.get_collection("medqa")
26
+ print(f"Collection loaded with {collection.count()} items")
27
  print(f"Loading MedCPT model...")
28
  model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
29
  print("Initialization complete!")
 
44
  num_results: int = 3
45
 
46
  class SearchResponse(BaseModel):
47
+ results: list