import os
from typing import Optional, Literal

from openai import OpenAI, AsyncOpenAI
from pydantic import BaseModel, Field

from modules.constants import PROMPT_LIBRARY

SYSTEM_PROMPT = """
You are an e-commerce fashion catalog assistant.
Classify products and generate detailed descriptions based on images.
"""

USER_PROMPT = """
Analyze this fashion product image and provide:
1) Master category, 2) Gender, 3) Sub-category, and 4) A detailed description.
"""


class ProductClassification(BaseModel):
    """Structured output model for fashion product classification and description"""

    master_category: Literal["Footwear", "Accessories", "Apparel", "Personal Care"] = (
        Field(description="The master category of the product")
    )
    gender: Literal["Men", "Women", "Unisex", "Boys", "Girls"] = Field(
        description="The target gender for the product"
    )
    sub_category: Literal[
        "Sandal",
        "Scarves",
        "Shoes",
        "Watches",
        "Innerwear",
        "Topwear",
        "Belts",
        "Bags",
        "Flip Flops",
        "Nails",
        "Bottomwear",
        "Fragrance",
        "Wallets",
        "Jewellery",
        "Loungewear and Nightwear",
        "Socks",
        "Headwear",
        "Lips",
        "Saree",
        "Ties",
        "Accessories",
        "Eyewear",
        "Dress",
        "Skin Care",
        "Stoles",
        "Makeup",
        "Cufflinks",
        "Skin",
        "Hair",
        "Apparel Set",
        "Water Bottle",
        "Eyes",
        "Shoe Accessories",
        "Umbrellas",
        "Mufflers",
        "Beauty Accessories",
        "Gloves",
        "Sports Accessories",
        "Perfumes",
        "Bath and Body",
    ] = Field(description="The specific sub-category of the product")
    description: str = Field(
        description="A detailed description of the product based on the image"
    )


def analyze_product_image(
    image_url: str,
    model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
    api_key: Optional[str] = None,
    provider: str = "Fireworks",
    prompt_style: Optional[str] = None,
) -> ProductClassification:
    """
    Analyze a fashion product image using a VLM with structured output

    Args:
        image_url: URL or base64-encoded image string (with data:image prefix)
        model: Model to use for inference (default: Qwen2.5 VL 72B)
        api_key: Fireworks API key (defaults to FIREWORKS_API_KEY env variable)
        provider: Provider to use for inference (default: Fireworks)
        prompt_style: Prompt style from library (concise, descriptive, explanatory).
            Defaults to the fallback prompts above.

    Returns:
        ProductClassification: Structured classification and description
    """
    if provider.lower() in ["fireworks", "fireworksai"]:
        client = OpenAI(
            api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
            base_url="https://api.fireworks.ai/inference/v1",
        )
    elif provider.lower() == "openai":
        client = OpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Get prompts from the library or fall back to the module defaults
    if prompt_style and prompt_style in PROMPT_LIBRARY:
        system_prompt = PROMPT_LIBRARY[prompt_style]["system"]
        user_prompt = PROMPT_LIBRARY[prompt_style]["user"]
    else:
        system_prompt = SYSTEM_PROMPT
        user_prompt = USER_PROMPT

    # Call the API with structured output
    completion = client.beta.chat.completions.parse(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ],
        response_format=ProductClassification,
    )

    # Extract and return the structured output
    return completion.choices[0].message.parsed
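

# --- Illustrative helper (not part of the original module) ---
# A minimal sketch of feeding a local file to analyze_product_image. The
# docstring above notes that image_url may be a base64-encoded string with a
# "data:image" prefix; this helper builds such a data URL. The helper name and
# the example path are hypothetical.
import base64


def _encode_image_as_data_url(path: str, mime_type: str = "image/jpeg") -> str:
    """Encode a local image file as a base64 data URL accepted by image_url."""
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


# Example call (hypothetical path):
# result = analyze_product_image(_encode_image_as_data_url("shirt.jpg"))
# print(result.master_category, result.sub_category)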


async def analyze_product_image_async(
    image_url: str,
    model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
    api_key: Optional[str] = None,
    provider: str = "Fireworks",
    prompt_style: Optional[str] = None,
) -> ProductClassification:
    """
    Async version of analyze_product_image for concurrent processing

    Args:
        image_url: URL or base64-encoded image string (with data:image prefix)
        model: Model to use for inference (default: Qwen2.5 VL 72B)
        api_key: API key (defaults to provider-specific env variable)
        provider: Provider to use for inference (default: Fireworks)
        prompt_style: Prompt style from library (concise, descriptive, explanatory).
            Defaults to the fallback prompts above.

    Returns:
        ProductClassification: Structured classification and description
    """
    if provider.lower() in ["fireworks", "fireworksai"]:
        client = AsyncOpenAI(
            api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
            base_url="https://api.fireworks.ai/inference/v1",
        )
    elif provider.lower() == "openai":
        client = AsyncOpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")

    # Get prompts from the library or fall back to the module defaults
    if prompt_style and prompt_style in PROMPT_LIBRARY:
        system_prompt = PROMPT_LIBRARY[prompt_style]["system"]
        user_prompt = PROMPT_LIBRARY[prompt_style]["user"]
    else:
        system_prompt = SYSTEM_PROMPT
        user_prompt = USER_PROMPT

    # Call the API with structured output
    completion = await client.beta.chat.completions.parse(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ],
        response_format=ProductClassification,
    )

    # Extract and return the structured output
    return completion.choices[0].message.parsed
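

# --- Illustrative concurrent usage (not part of the original module) ---
# A minimal sketch of the concurrent processing mentioned in the docstring of
# analyze_product_image_async: fan several images out with asyncio.gather,
# bounded by a semaphore to avoid hammering the API. The helper name,
# concurrency limit, and example URL are hypothetical.
import asyncio


async def _analyze_many_concurrently(
    image_urls: list[str], max_concurrency: int = 5
) -> list[ProductClassification]:
    """Run analyze_product_image_async over image_urls with bounded concurrency."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def _one(url: str) -> ProductClassification:
        # At most max_concurrency requests are in flight at any time
        async with semaphore:
            return await analyze_product_image_async(url)

    return await asyncio.gather(*(_one(url) for url in image_urls))


# Example (hypothetical URL):
# results = asyncio.run(_analyze_many_concurrently(["https://example.com/a.jpg"]))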


def batch_analyze_products(
    image_urls: list[str],
    model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
    api_key: Optional[str] = None,
    provider: str = "Fireworks",
) -> list[Optional[ProductClassification]]:
    """
    Analyze multiple fashion product images sequentially

    Args:
        image_urls: List of image URLs or base64-encoded strings
        model: Model to use for inference
        api_key: API key (defaults to provider-specific env variable)
        provider: Provider to use for inference (default: Fireworks)

    Returns:
        list[Optional[ProductClassification]]: List of structured classifications
            (None for failed analyses)
    """
    results = []
    for idx, image_url in enumerate(image_urls):
        try:
            result = analyze_product_image(
                image_url=image_url, model=model, api_key=api_key, provider=provider
            )
            results.append(result)
            print(f"Processed image {idx + 1}/{len(image_urls)}")
        except Exception as e:
            print(f"Error processing image {idx + 1}: {e}")
            results.append(None)
    return results
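

# --- Illustrative entry point (not part of the original module) ---
# A minimal sketch of running batch_analyze_products directly; the URLs are
# hypothetical placeholders and FIREWORKS_API_KEY is assumed to be set in the
# environment.
if __name__ == "__main__":
    sample_urls = [
        "https://example.com/product_1.jpg",
        "https://example.com/product_2.jpg",
    ]
    for item in batch_analyze_products(sample_urls):
        if item is not None:
            print(f"{item.master_category} / {item.sub_category}: {item.description}")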