'''
import os
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import uvicorn
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
app = FastAPI()
# Define the cache directory path within the project
cache_dir = str(Path(__file__).parent.resolve() / 'cache')
# Create the cache directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)
# NOTE: transformers reads TRANSFORMERS_CACHE when it is imported, so setting the
# environment variable here (after the import above) would have no effect; instead,
# the cache location is passed explicitly via `cache_dir=` in from_pretrained below.
print(f"Transformers cache directory: {cache_dir}")
# Ensure your Hugging Face token is set as an environment variable
huggingface_token = os.environ.get("TOKEN")
if not huggingface_token:
    raise ValueError("TOKEN environment variable is not set.")
# Load the tokenizer and model using Hugging Face's library with the token
try:
    tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b-it", token=huggingface_token, cache_dir=cache_dir
    )
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b-it", token=huggingface_token, cache_dir=cache_dir
    )
    # Initialize the pipeline; use the GPU if one is available, otherwise fall back to CPU
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}")
# Data model for the request body
class Item(BaseModel):
    prompt: str
    temperature: float = 0.7
    max_new_tokens: int = 128
# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item):
    # Validate outside the try block so the 400 is not swallowed and re-raised as a 500
    if not item.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required")
    try:
        output = generator(
            item.prompt,
            do_sample=True,  # sampling must be enabled for `temperature` to have any effect
            temperature=item.temperature,
            max_new_tokens=item.max_new_tokens,
        )
        return {"generated_text": output[0]['generated_text']}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
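
# A minimal usage sketch for the endpoint above, assuming the server is running
# locally on port 8000 and TOKEN holds a valid Hugging Face access token
# (the prompt text here is purely illustrative):
#   curl -X POST http://localhost:8000/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about alpacas.", "temperature": 0.7, "max_new_tokens": 64}'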
'''