import gradio as gr import pandas as pd from pathlib import Path from typing import Optional import os from datasets import load_dataset from PIL import Image import io from dotenv import load_dotenv from src.modules.vlm_inference import analyze_product_image from src.modules.data_processing import pil_to_base64 from src.modules.evals import run_inference_on_dataframe from src.modules.viz import ( load_evaluation_data, create_accuracy_plot, create_precision_recall_plot, ) load_dotenv() # Constants AVAILABLE_MODELS = { "Qwen3-VL-30B-A3B": "accounts/fireworks/models/qwen3-vl-30b-a3b-instruct", "Qwen3-VL-235B-A22B": "accounts/fireworks/models/qwen3-vl-235b-a22b-instruct", "Qwen2.5-VL-32B": "accounts/fireworks/models/qwen2p5-vl-32b-instruct", } MAX_CONCURRENT_REQUESTS = 10 FILE_PATH = Path(__file__).parents[1] ASSETS_PATH = FILE_PATH / "assets" DATA_PATH = FILE_PATH / "data" _NOTEBOOK_PATH = "https://huggingface.co/spaces/fireworks-ai/catalog-extract/blob/main/notebooks/01-eda-and-fine-tuning.ipynb" # Prompt style display names PROMPT_STYLES = { "Data Management": "concise", "Website/Sales": "descriptive", "Customer Support": "explanatory", } def analyze_single_image( image_input, model_name: str, prompt_style_display: Optional[str] = None, ) -> tuple[str, str, str, str]: """ Process a single product image and return classification results Args: image_input: PIL Image or file path model_name: Selected model name api_key: Optional API key override prompt_style_display: Display name for prompt style (e.g., "Data Management") Returns: tuple: (master_category, gender, sub_category, description) """ if image_input is None: return "No image provided", "", "", "" try: img_b64 = pil_to_base64(image_input) model_id = AVAILABLE_MODELS[model_name] api_key = os.getenv("FIREWORKS_API_KEY") prompt_style = ( PROMPT_STYLES.get(prompt_style_display) if prompt_style_display else None ) result = analyze_product_image( image_url=img_b64, model=model_id, api_key=api_key, provider="Fireworks", prompt_style=prompt_style, ) # Format results master_cat = result.master_category gender = result.gender sub_cat = result.sub_category description = result.description return master_cat, gender, sub_cat, description except Exception as e: error_msg = f"Error: {str(e)}" return error_msg, "", "", "" def process_batch_dataset( csv_file, model_name: str, api_key: Optional[str] = None, max_concurrent: int = MAX_CONCURRENT_REQUESTS, ) -> tuple[Optional[pd.DataFrame], str]: """ Process uploaded CSV dataset with product images Args: csv_file: Uploaded CSV file with image data model_name: Selected model name api_key: Optional API key override max_concurrent: Max concurrent API requests Returns: tuple: (results_dataframe, summary_statistics) """ if csv_file is None: return None, "No dataset uploaded" try: # Load dataset df = pd.read_csv(csv_file.name) # Validate required columns required_cols = ["id", "image"] if not all(col in df.columns for col in required_cols): return None, f"Dataset must contain columns: {required_cols}" # Determine provider model_id = AVAILABLE_MODELS[model_name] # Get API key if api_key is None: api_key = os.getenv("FIREWORKS_API_KEY") # Run batch inference results_df = run_inference_on_dataframe( df=df, model=model_id, api_key=api_key, provider="Fireworks", max_concurrent_requests=max_concurrent, ) # Generate summary statistics total_processed = len(results_df) successful = results_df["pred_masterCategory"].notna().sum() failed = total_processed - successful summary = f""" Batch Processing Complete: - Total images: {total_processed} - Successfully classified: {successful} - Failed: {failed} - Success rate: {(successful / total_processed) * 100:.1f}% """ return results_df, summary except Exception as e: return None, f"Error processing dataset: {str(e)}" def load_example_data() -> pd.DataFrame: """Load example product images from HuggingFace dataset""" # Load dataset from HuggingFace ds = load_dataset("ceyda/fashion-products-small") df = ds["train"].to_pandas() # Select 20 random samples sample_df = df.sample(n=20, random_state=42).reset_index(drop=True) # Keep only relevant columns for display display_df = sample_df[["id", "masterCategory", "gender", "subCategory"]].copy() display_df["image_data"] = sample_df["image"] return display_df def get_image_from_row(examples_df: pd.DataFrame, evt: gr.SelectData) -> Image.Image: """Get PIL Image from selected row in examples table""" if evt.index is None or len(evt.index) == 0: return None row_idx = evt.index[0] if row_idx >= len(examples_df): return None # Get the image data from the stored row image_data = examples_df.iloc[row_idx]["image_data"] # Convert to PIL Image if it's a dict (from HuggingFace datasets) if isinstance(image_data, dict): if "bytes" in image_data: return Image.open(io.BytesIO(image_data["bytes"])) elif "path" in image_data: return Image.open(image_data["path"]) # Return as-is if already a PIL Image return image_data def create_demo_interface(): """ Create the Gradio interface with custom theme and layout """ # Load example data at startup example_data = load_example_data() with gr.Blocks( title="Product Catalog Cleansing", theme=gr.themes.Soft(), ) as demo: # Store examples dataframe in state examples_state = gr.State(value=example_data) # Header gr.Markdown( """ # Product Catalog Cleansing Automate product classification, attribute extraction, and catalog enrichment using state-of-the-art multimodal AI. Fine-tuned SOTA OSS models on FireworksAI. """ ) # Model Selection (shared across tabs) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Powered by") gr.Image( value=str(ASSETS_PATH / "fireworks_logo.png"), height=60, width=200, show_label=False, show_download_button=False, container=False, show_fullscreen_button=False, show_share_button=False, ) model_selector = gr.Dropdown( choices=list(AVAILABLE_MODELS.keys()), value=list(AVAILABLE_MODELS.keys())[0], label="Select Model", ) prompt_selector = gr.Dropdown( choices=list(PROMPT_STYLES.keys()), value="Website/Sales", label="Description Style", ) with gr.Tabs(): with gr.TabItem("📸 Image Analysis 📸 "): gr.Markdown("### Upload a product image or select from table below") with gr.Row(): # Left column - Input with gr.Column(scale=1): image_input = gr.Image( label="Upload Product Image", type="pil", height=400 ) analyze_btn = gr.Button( "🔍 Analyze Product", variant="primary", size="lg" ) # Right column - Results with gr.Column(scale=1): gr.Markdown("### Classification Results") master_category_output = gr.Textbox( label="Master Category", interactive=False ) gender_output = gr.Textbox(label="Gender", interactive=False) subcategory_output = gr.Textbox( label="Sub-Category", interactive=False ) description_output = gr.Textbox( label="AI-Generated Description", interactive=False, lines=4 ) # Example Products Table gr.Markdown("### 📚 Example Products (Click a row to load image)") examples_table = gr.Dataframe( value=example_data[ ["id", "masterCategory", "gender", "subCategory"] ], label="Click on any of the rows below to load the product image for analysis", interactive=False, wrap=True, ) # Wire up single image analysis analyze_btn.click( fn=analyze_single_image, inputs=[ image_input, model_selector, prompt_selector, ], outputs=[ master_category_output, gender_output, subcategory_output, description_output, ], ) # Allow clicking table row to load image examples_table.select( fn=get_image_from_row, inputs=[examples_state], outputs=[image_input], ) # Tab 3: Model Evaluation (interactive charts) with gr.TabItem("📈 Model Performance"): gr.Markdown( """ ### Evaluation Results on Fashion Product Dataset Model fine tuned on over 14k images and tested on a validation set of 1000 images. Images pulled from [HuggingFace Datasets](https://huggingface.co/datasets/ceyda/fashion-products-small) """ ) eval_df = load_evaluation_data(DATA_PATH) if eval_df is not None: all_models = eval_df["model"].unique().tolist() all_categories = eval_df["category"].unique().tolist() with gr.Row(): model_filter = gr.CheckboxGroup( choices=all_models, value=all_models, label="Select Models to Display", interactive=True, ) category_filter = gr.CheckboxGroup( choices=all_categories, value=all_categories, label="Select Categories to Display", interactive=True, ) with gr.Row(): accuracy_plot = gr.Plot() with gr.Row(): precision_recall_plot = gr.Plot() def update_plots(selected_models, selected_categories): acc_fig = create_accuracy_plot( eval_df, selected_models, selected_categories ) pr_fig = create_precision_recall_plot( eval_df, selected_models, selected_categories ) return acc_fig, pr_fig model_filter.change( fn=update_plots, inputs=[model_filter, category_filter], outputs=[accuracy_plot, precision_recall_plot], ) category_filter.change( fn=update_plots, inputs=[model_filter, category_filter], outputs=[accuracy_plot, precision_recall_plot], ) demo.load( fn=update_plots, inputs=[model_filter, category_filter], outputs=[accuracy_plot, precision_recall_plot], ) else: gr.Markdown( "⚠️ Evaluation data not found. Please run `python generate_eval_results.py` first." ) with gr.Row(): with gr.Column(): gr.Markdown( """ **Key Findings:** - Qwen2.5-VL-72B-SFT achieves >95% accuracy on masterCategory - Fine-tuned model shows 18% improvement on subCategory vs base model - All models maintain >90% precision and recall on gender classification """ ) with gr.Column(): gr.Markdown( f""" ### 🚀 Want to Fine-Tune Your Own Model? 🚀 Check out our comprehensive fine-tuning cookbook to learn how we built this model! The notebook covers: - Dataset preparation and exploration - Data preprocessing for VLM fine-tuning - Fine-tuning with Fireworks AI - Model evaluation and deployment **[📓 Open Fine-Tuning Notebook]({_NOTEBOOK_PATH})** Perfect for adapting this approach to your own product catalog or domain-specific use case. """ ) return demo if __name__ == "__main__": # Launch demo demo = create_demo_interface() demo.launch( server_name="0.0.0.0", server_port=7860, share=False, show_error=True, )