| ''' | |
| import os | |
| from pathlib import Path | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
app = FastAPI()

# Keep the Hugging Face cache inside the project tree so downloaded model
# weights live next to the code instead of the user's home directory.
cache_dir = str(Path(__file__).parent.resolve() / 'cache')
os.makedirs(cache_dir, exist_ok=True)

# NOTE(review): `transformers` is imported above, *before* this assignment, and
# the library resolves its default cache location from the environment at
# import time — so this may have no effect. The model-loading code passes
# `cache_dir=` explicitly as the reliable mechanism; this env var is kept for
# any child processes / tooling that read it.
os.environ['TRANSFORMERS_CACHE'] = cache_dir
print(f"Transformers cache directory: {os.environ['TRANSFORMERS_CACHE']}")

# Fail fast at startup if the Hugging Face access token is missing, rather
# than failing later inside `from_pretrained`.
huggingface_token = os.environ.get("TOKEN")
if not huggingface_token:
    raise ValueError("TOKEN environment variable is not set.")
# Load the tokenizer and model, then build the generation pipeline.
# `cache_dir` is passed explicitly because the TRANSFORMERS_CACHE environment
# variable is only assigned after `transformers` has already been imported,
# so the env-var route cannot be relied on for the cache location.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b-it",
        token=huggingface_token,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b-it",
        token=huggingface_token,
        cache_dir=cache_dir,
    )
    # Text-generation pipeline; device=0 targets the first GPU (use -1 for CPU).
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0,
    )
except Exception as e:
    # Chain the original exception so the startup failure stays diagnosable.
    raise RuntimeError(f"Failed to load model: {e}") from e
| # Data model for the request body | |
class Item(BaseModel):
    """Request body for the text-generation endpoint."""

    # The text prompt to continue; must be non-empty (enforced by the handler).
    prompt: str
    # Sampling temperature forwarded to the generation pipeline.
    temperature: float = 0.7
    # Upper bound on tokens generated beyond the prompt.
    max_new_tokens: int = 128
| # Endpoint for generating text | |
@app.post("/")
async def generate_text(item: Item):
    """Run the text-generation pipeline on `item.prompt`.

    Returns:
        dict: {"generated_text": <pipeline output>}.

    Raises:
        HTTPException: 400 when the prompt is empty; 500 when generation fails.
    """
    # Validate BEFORE the broad except below. In the original, the 400 was
    # raised inside the try block, so `except Exception` caught it and
    # remapped it to a 500.
    if not item.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required")
    try:
        output = generator(
            item.prompt,
            temperature=item.temperature,
            max_new_tokens=item.max_new_tokens,
        )
        return {"generated_text": output[0]['generated_text']}
    except HTTPException:
        # Never convert an intentional HTTP error into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
# Run the API with uvicorn when executed as a script; binds all interfaces
# on port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
| ''' |