# app.py
# FINAL, DEFINITIVE VERSION
# Corrects all model loading paths and relies on the stable requirements.txt

import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM, set_seed
from paddleocr import PaddleOCR
import supervision as sv
import warnings

warnings.filterwarnings("ignore")
# --- Global Configuration ---
REPO_ID = "microsoft/OmniParser-v2.0"
# CORRECTED file paths as they exist in the Hugging Face repository
DETECTION_MODEL_FILENAME = "icon_detect/model.pt"
CAPTION_MODEL_SUBFOLDER = "icon_caption"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"INFO: Using device: {DEVICE}")
set_seed(42)
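# NOTE: the loaders below assume the Hub repo has this layout (true of
# microsoft/OmniParser-v2.0 at the time of writing):
#   icon_detect/model.pt  -> YOLO detector weights, fetched via hf_hub_download
#   icon_caption/         -> captioning model + processor, loaded via subfolder=
# If either path moves upstream, the loaders fail and return None.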
# --- Model Loading ---
def load_detection_model():
    print("INFO: Loading detection model...")
    try:
        model_path = hf_hub_download(repo_id=REPO_ID, filename=DETECTION_MODEL_FILENAME)
        model = YOLO(model_path)
        print("INFO: Detection model loaded successfully.")
        return model
    except Exception as e:
        print(f"ERROR: Failed to load detection model: {e}")
        return None
def load_caption_model():
    print("INFO: Loading captioning model...")
    try:
        # CORRECTED loading logic using repo_id and subfolder arguments
        model = AutoModelForCausalLM.from_pretrained(
            REPO_ID,
            subfolder=CAPTION_MODEL_SUBFOLDER,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            attn_implementation="eager"
        ).to(DEVICE)
        processor = AutoProcessor.from_pretrained(
            REPO_ID,
            subfolder=CAPTION_MODEL_SUBFOLDER,
            trust_remote_code=True
        )
        print("INFO: Captioning model loaded successfully.")
        return model, processor
    except Exception as e:
        print(f"ERROR: Failed to load captioning model: {e}")
        return None, None
def load_ocr_model():
    print("INFO: Loading OCR model...")
    try:
        ocr_model = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=(DEVICE == "cuda"), show_log=False)
        print("INFO: OCR model loaded successfully.")
        return ocr_model
    except Exception as e:
        print(f"ERROR: Failed to load OCR model: {e}")
        return None

detection_model = load_detection_model()
caption_model, caption_processor = load_caption_model()
ocr_model = load_ocr_model()
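# Optional warm-up (a sketch, not part of the original flow): running the detector
# once on a dummy frame at startup surfaces CUDA or weight problems before the
# first real request hits the endpoint.
# if detection_model is not None:
#     detection_model(np.zeros((64, 64, 3), dtype=np.uint8), verbose=False)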
# --- Inference Pipeline ---
def run_captioning(image, text, model, processor):
    prompt = f"<OD> <ref> {text} </ref>"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False
    )
    # batch_decode returns a list of strings; post_process_generation expects a
    # single string, so take the first (and only) element.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_text = processor.post_process_generation(generated_text, task="<OD>", image_size=image.size)
    final_caption_list = parsed_text.get('<OD>', {}).get('labels', [])
    return final_caption_list[0] if final_caption_list else "No description available"
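# For reference, a Florence-2-style post_process_generation result for task="<OD>"
# looks roughly like this (values illustrative, not from a real run):
#   {'<OD>': {'bboxes': [[8.0, 6.0, 120.0, 48.0]], 'labels': ['search button']}}
# which is why the caption is read from the 'labels' list above.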
def is_box_contained(outer_box, inner_box):
    return (outer_box[0] <= inner_box[0] and
            outer_box[1] <= inner_box[1] and
            outer_box[2] >= inner_box[2] and
            outer_box[3] >= inner_box[3])
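# Examples (boxes are [x1, y1, x2, y2]):
#   is_box_contained([0, 0, 100, 100], [10, 20, 60, 80])   -> True
#   is_box_contained([0, 0, 100, 100], [90, 90, 110, 110]) -> False (spills past the edge)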
def predict(input_image: Image.Image):
    # The loaders return None on failure, so test identity rather than truthiness.
    if detection_model is None or caption_model is None or ocr_model is None:
        error_messages = []
        if detection_model is None: error_messages.append("Detection model failed.")
        if caption_model is None: error_messages.append("Captioning model failed.")
        if ocr_model is None: error_messages.append("OCR model failed.")
        return {"error": " ".join(error_messages) + " Check container logs for details."}
    image_np = np.array(input_image.convert("RGB"))
    ocr_results = ocr_model.ocr(image_np, cls=True)[0]
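    # PaddleOCR's classic pipeline returns one result list per input image; each
    # line is roughly [[[x1,y1],[x2,y2],[x3,y3],[x4,y4]], ("text", confidence)],
    # which the unpacking below assumes.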
    ocr_texts = []
    if ocr_results:
        for line in ocr_results:
            points, (text, conf) = line
            x_coords = [p[0] for p in points]
            y_coords = [p[1] for p in points]
            ocr_texts.append({"box": [min(x_coords), min(y_coords), max(x_coords), max(y_coords)], "text": text, "conf": conf})
    detection_results = detection_model(image_np, verbose=False)[0]
    detections = sv.Detections.from_ultralytics(detection_results)
    parsed_elements = []
    element_id_counter = 0
    for i in range(len(detections)):
        box = detections.xyxy[i].astype(int)
        confidence = detections.confidence[i]
        class_name = detection_model.names[detections.class_id[i]]
        cropped_image = input_image.crop(tuple(box))
        caption = run_captioning(cropped_image, "Describe this UI element.", caption_model, caption_processor)
        contained_text = " ".join([o["text"] for o in ocr_texts if is_box_contained(box.tolist(), o["box"])])
        parsed_elements.append({
            "id": element_id_counter, "box_2d": box.tolist(), "type": class_name,
            "text": contained_text.strip(), "description": caption, "confidence": float(confidence)
        })
        element_id_counter += 1
    for ocr in ocr_texts:
        if not any(is_box_contained(el["box_2d"], ocr["box"]) for el in parsed_elements):
            parsed_elements.append({
                "id": element_id_counter, "box_2d": [int(p) for p in ocr["box"]], "type": "text_label",
                "text": ocr["text"], "description": "A text label.", "confidence": float(ocr["conf"])
            })
            element_id_counter += 1
    return {"parsed_elements": parsed_elements}
# --- Gradio Interface ---
with gr.Blocks(css="footer {display: none!important}") as demo:
    gr.Markdown("# Microsoft OmniParser-v2 API Endpoint\nUpload a UI screenshot to get a parsed JSON output.")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Input UI Screenshot")
        json_output = gr.JSON(label="Parsed UI Elements")
    submit_button = gr.Button("Parse UI", variant="primary")
    submit_button.click(fn=predict, inputs=[image_input], outputs=[json_output], api_name="predict")

demo.launch()
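# Calling the endpoint from a client (a sketch; "user/space-name" is a placeholder
# for wherever this app is hosted):
#
#   from gradio_client import Client, handle_file
#   client = Client("user/space-name")
#   result = client.predict(handle_file("screenshot.png"), api_name="/predict")
#   print(result["parsed_elements"])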