from fastapi import FastAPI, UploadFile, File, HTTPException, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional, Dict, AsyncGenerator
import os
from dotenv import load_dotenv
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.text_utils import CharacterTextSplitter, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import asyncio
import tempfile
import shutil
import json
from uuid import uuid4
# Load environment variables
load_dotenv()

app = FastAPI()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chat_openai = ChatOpenAI()

# Define prompts
system_template = """\
You are a helpful assistant that provides concise, direct answers based on the provided context.
If the answer cannot be found in the context, simply say "I don't know" or "The information is not available in the provided context."
Keep your answers brief and to the point."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}

Answer the question concisely based on the context above."""
user_role_prompt = UserRolePrompt(user_prompt_template)
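
# Example (sketch): at query time the pipeline renders the user prompt with the
# retrieved context, e.g.
#     user_role_prompt.create_message(question="What is X?", context="...retrieved chunks...")
# which fills the {question} and {context} placeholders in user_prompt_template.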

# Session management
sessions: Dict[str, Dict] = {}


class Query(BaseModel):
    text: str
    k: int = 4


class DocumentResponse(BaseModel):
    text: str
    type: str  # 'answer' or 'context'
    score: Optional[float] = None

class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str, k: int = 4) -> AsyncGenerator[str, None]:
        # Get top k most relevant chunks
        context_list = self.vector_db_retriever.search_by_text(user_query, k=k)

        # Format context
        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        # Format prompts
        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(
            question=user_query,
            context=context_prompt
        )

        # Stream only the LLM response
        async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
            yield json.dumps({
                "type": "token",
                "text": chunk
            })

        # Send context information once at the end
        yield json.dumps({
            "type": "context",
            "context": [{"text": text, "score": score} for text, score in context_list]
        })
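
# Example (sketch) of consuming the pipeline directly, assuming a VectorDatabase
# named `vector_db` that has already been built from document chunks:
#
#     async def demo():
#         pipeline = RetrievalAugmentedQAPipeline(llm=chat_openai, vector_db_retriever=vector_db)
#         async for message in pipeline.arun_pipeline("What is this document about?"):
#             payload = json.loads(message)
#             if payload["type"] == "token":
#                 print(payload["text"], end="", flush=True)
#             else:  # a single "context" message arrives after all tokens
#                 print("\nRetrieved chunks:", len(payload["context"]))
#
#     # asyncio.run(demo())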

def process_file(file_path: str, file_name: str):
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        raise HTTPException(status_code=400, detail="Only PDF files are supported")
    documents = loader.load_documents()
    texts = text_splitter.split_texts(documents)
    return texts

@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
if not file.filename.lower().endswith('.pdf'):
raise HTTPException(status_code=400, detail="Only PDF files are supported")
try:
# Read the file content directly into memory
content = await file.read()
# Create a temporary file in a directory we know exists
temp_dir = "/tmp" # Using /tmp which is writable in most environments
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.join(temp_dir, f"upload_{file.filename}")
# Write the content to the temporary file
with open(temp_path, 'wb') as temp_file:
temp_file.write(content)
try:
# Process the file
texts = process_file(temp_path, file.filename)
# Create a new session
session_id = str(uuid4())
vector_db = VectorDatabase()
await vector_db.abuild_from_list(texts)
# Store session data
sessions[session_id] = {
"vector_db": vector_db,
"texts": texts
}
return {
"session_id": session_id,
"message": f"Document processed successfully. Added {len(texts)} chunks to the database."
}
finally:
# Clean up the temporary file
try:
if os.path.exists(temp_path):
os.unlink(temp_path)
except Exception as e:
print(f"Warning: Could not delete temporary file: {e}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
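
# Example request (sketch; the host/port and sample.pdf are placeholders for
# your own deployment and file):
#
#     curl -F "file=@sample.pdf" http://localhost:9000/upload
#
# On success the JSON response contains the session_id needed by
# /query/{session_id} and /ws/{session_id}.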
@app.post("/query/{session_id}")
async def query_documents(session_id: str, query: Query):
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
try:
session = sessions[session_id]
vector_db = session["vector_db"]
# Initialize RAG pipeline
rag_pipeline = RetrievalAugmentedQAPipeline(
llm=chat_openai,
vector_db_retriever=vector_db
)
# Create streaming response
async def generate():
async for chunk in rag_pipeline.arun_pipeline(query.text, query.k):
yield f"data: {chunk}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream"
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
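
# Example request (sketch; SESSION_ID is the value returned by /upload):
#
#     curl -N -X POST "http://localhost:9000/query/$SESSION_ID" \
#          -H "Content-Type: application/json" \
#          -d '{"text": "What is the main topic?", "k": 4}'
#
# The response is a text/event-stream of "data: {...}" lines: a stream of
# {"type": "token", ...} messages followed by one {"type": "context", ...} message.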

@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    await websocket.accept()

    if session_id not in sessions:
        await websocket.close(code=1008, reason="Session not found")
        return

    try:
        session = sessions[session_id]
        vector_db = session["vector_db"]

        while True:
            data = await websocket.receive_text()
            query = json.loads(data)

            # Initialize RAG pipeline
            rag_pipeline = RetrievalAugmentedQAPipeline(
                llm=chat_openai,
                vector_db_retriever=vector_db
            )

            # Stream the response. Each chunk yielded by the pipeline is already
            # a JSON-encoded message with its own "type" field ("token" or
            # "context"), so forward it to the client as-is.
            async for chunk in rag_pipeline.arun_pipeline(query["text"], query.get("k", 4)):
                await websocket.send_text(chunk)
    except Exception as e:
        await websocket.close(code=1011, reason=str(e))
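
# Example client (sketch, using the third-party `websockets` package, which this
# app does not itself require; SESSION_ID is the value returned by /upload):
#
#     import asyncio, json, websockets
#
#     async def chat(session_id: str):
#         uri = f"ws://localhost:9000/ws/{session_id}"
#         async with websockets.connect(uri) as ws:
#             await ws.send(json.dumps({"text": "Summarize the document", "k": 4}))
#             while True:
#                 message = json.loads(await ws.recv())
#                 if message["type"] == "token":
#                     print(message["text"], end="", flush=True)
#                 else:
#                     break  # the final "context" message ends this answer
#
#     # asyncio.run(chat(SESSION_ID))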

# Mount the static frontend last so the API routes defined above take precedence
app.mount("/", StaticFiles(directory="static", html=True), name="static")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=9000)