import subprocess
import shlex
import sys
import os
import tempfile
import numpy as np
import io
import base64
import json
import uvicorn
import torch
from PIL import Image
# Install the custom component if needed
subprocess.run(
    shlex.split(
        "pip install ./gradio_magicquillv2-0.0.1-py3-none-any.whl"
    )
)
import gradio as gr
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client, handle_file
from gradio_magicquillv2 import MagicQuillV2
from util import (
    read_base64_image as read_base64_image_utils,
    tensor_to_base64,
    get_mask_bbox
)
# --- Configuration ---
# Set this to the URL of your backend Space (running app_backend.py)
BACKEND_URL = "LiuZichen/MagicQuillV2"
SAM_URL = "LiuZichen/MagicQuillHelper"

print(f"Target Backend URL: {BACKEND_URL}")

# We still initialize the SAM client globally, since it is a helper CPU Space
# and should not consume ZeroGPU quota.
print(f"Connecting to SAM client at: {SAM_URL}")
try:
    sam_client = Client(SAM_URL)
except Exception as e:
    print(f"Failed to connect to SAM client: {e}")
    sam_client = None
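# Architecture note: this Space acts as a thin frontend. Image generation is
# proxied to the ZeroGPU backend Space (BACKEND_URL) with the user's auth
# headers forwarded, so quota is charged to the user rather than the server;
# segmentation is proxied to the CPU helper Space (SAM_URL).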
def get_zerogpu_headers(request_headers):
    """
    Extract ZeroGPU-specific headers from the incoming request headers.
    These are required to forward the user's quota token to the backend.
    """
    headers = {}
    if request_headers:
        # These are the headers HF injects for ZeroGPU authentication and tracking
        target_headers = [
            "x-ip-token",
            "x-zerogpu-token",
            "x-zerogpu-uuid",
            "authorization",
            "cookie",
        ]
        for h in target_headers:
            val = request_headers.get(h)
            if val:
                headers[h] = val
    return headers
def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg, request: gr.Request):
    """
    Handler for the Gradio UI.
    Note the 'request: gr.Request' argument - Gradio injects it automatically.
    """
    # Unpack the canvas state produced by the MagicQuillV2 component
    merged_image = x['from_frontend']['img']
    total_mask = x['from_frontend']['total_mask']
    original_image = x['from_frontend']['original_image']
    add_color_image = x['from_frontend']['add_color_image']
    add_edge_mask = x['from_frontend']['add_edge_mask']
    remove_edge_mask = x['from_frontend']['remove_edge_mask']
    fill_mask = x['from_frontend']['fill_mask']
    add_prop_image = x['from_frontend']['add_prop_image']
    positive_prompt = x['from_backend']['prompt']
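    # 1. Extract the user's ZeroGPU headers from the incoming request.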
    forward_headers = get_zerogpu_headers(request.headers)
    # Log header keys only; the values include auth tokens and cookies
    print(f"Debug: Received header keys: {list(request.headers.keys())}")
    print(f"Debug: Forwarding header keys: {list(forward_headers.keys())}")
    try:
        # 2. Instantiate a client specifically for this request with the forwarded headers.
        # This ensures the backend sees the 'x-zerogpu-token' of the user, not the server.
        # gradio_client caches schemas, so re-init is relatively cheap but necessary for headers.
        client = Client(BACKEND_URL, headers=forward_headers)

        # 3. Call the backend API with the canvas state and generation parameters.
        res_base64 = client.predict(
            merged_image,
            total_mask,
            original_image,
            add_color_image,
            add_edge_mask,
            remove_edge_mask,
            fill_mask,
            add_prop_image,
            positive_prompt,
            negative_prompt,
            fine_edge,
            fix_perspective,
            grow_size,
            edge_strength,
            color_strength,
            local_strength,
            seed,
            steps,
            cfg,
            api_name="/generate"
        )
        x["from_backend"]["generated_image"] = res_base64
    except Exception as e:
        print(f"Error in generation: {e}")
        x["from_backend"]["generated_image"] = None
    return x
# --- Gradio UI ---
with gr.Blocks(title="MagicQuill V2") as demo:
    with gr.Row(elem_classes="row"):
        text = gr.Markdown(
            """
            # Welcome to MagicQuill V2! Give us a [GitHub star](https://github.com/zliucz/magicquillv2) if you are interested.
            Click the [link](https://magicquill.art/v2) to view our demo and tutorial. The paper is now on [arXiv](https://arxiv.org/abs/2512.03046). The [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu) quota is 4 minutes per day for regular users and 25 minutes per day for PRO users.
            """)
    with gr.Row():
        ms = MagicQuillV2()
    with gr.Row():
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column():
            with gr.Accordion("Parameters", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", value="", interactive=True)
                fine_edge = gr.Radio(label="Fine Edge", choices=['enable', 'disable'], value='disable', interactive=True)
                fix_perspective = gr.Radio(label="Fix Perspective", choices=['enable', 'disable'], value='disable', interactive=True)
                grow_size = gr.Slider(label="Grow Size", minimum=10, maximum=100, value=50, step=1, interactive=True)
                edge_strength = gr.Slider(label="Edge Strength", minimum=0.0, maximum=5.0, value=0.6, step=0.01, interactive=True)
                color_strength = gr.Slider(label="Color Strength", minimum=0.0, maximum=5.0, value=1.5, step=0.01, interactive=True)
                local_strength = gr.Slider(label="Local Strength", minimum=0.0, maximum=5.0, value=1.0, step=0.01, interactive=True)
                seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
                steps = gr.Slider(label="Steps", minimum=0, maximum=50, value=20, interactive=True)
                cfg = gr.Slider(label="CFG", minimum=0.0, maximum=20.0, value=3.5, step=0.1, interactive=True)

    btn.click(
        generate_image_handler,
        # Note: 'request' is NOT listed in inputs; Gradio injects the gr.Request
        # argument automatically based on its type hint.
        inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg],
        outputs=ms
    )
# --- FastAPI App ---
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
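# Gradio derives absolute root URLs from the incoming request, which breaks when
# the app is mounted behind a proxy path; override the resolver so it trusts the
# configured root_path directly.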
def get_root_url(request: Request, route_path: str, root_path: str | None):
    return root_path

gr.route_utils.get_root_url = get_root_url
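# NOTE (assumed): the @app.post registrations below appear to have been lost in
# extraction; without them these handlers are unreachable. The paths are
# illustrative guesses and must match whatever the frontend actually calls.
@app.post("/process_background_img")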
async def process_background_img(request: Request):
    img = await request.json()
    from util import process_background
    # process_background returns a tensor [1, H, W, 3] in uint8 or float
    resized_img_tensor = process_background(img)
    # tensor_to_base64 from util expects a tensor
    resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
        resized_img_tensor,
        quality=80,
        method=6
    )
    return resized_img_base64
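# NOTE (assumed): path is an illustrative guess, as above.
@app.post("/segmentation")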
async def segmentation(request: Request):
    json_data = await request.json()
    image_base64 = json_data.get("image", None)
    coordinates_positive = json_data.get("coordinates_positive", None)
    coordinates_negative = json_data.get("coordinates_negative", None)
    bboxes = json_data.get("bboxes", None)

    if sam_client is None:
        return {"error": "sam client not initialized"}

    # Round click coordinates to integer pixels and JSON-encode them for the API
    pos_coordinates = None
    if coordinates_positive:
        pos_coordinates = json.dumps([
            {'x': int(round(coord['x'])), 'y': int(round(coord['y']))}
            for coord in coordinates_positive
        ])

    neg_coordinates = None
    if coordinates_negative:
        neg_coordinates = json.dumps([
            {'x': int(round(coord['x'])), 'y': int(round(coord['y']))}
            for coord in coordinates_negative
        ])

    # Normalize boxes to (x_min, y_min, x_max, y_max), dropping incomplete ones
    bboxes_xyxy = None
    if bboxes:
        valid_bboxes = []
        for bbox in bboxes:
            if (bbox.get("startX") is None or bbox.get("startY") is None or
                    bbox.get("endX") is None or bbox.get("endY") is None):
                continue
            x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0)
            y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0)
            x_max = max(int(bbox["startX"]), int(bbox["endX"]))
            y_max = max(int(bbox["startY"]), int(bbox["endY"]))
            valid_bboxes.append((x_min, y_min, x_max, y_max))
        bboxes_xyxy = valid_bboxes
        if bboxes_xyxy:
            bboxes_xyxy = json.dumps(bboxes_xyxy)

    print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")
    try:
        # Decode the base64 payload into a PIL image
        image_bytes = read_base64_image_utils(image_base64)
        pil_image = Image.open(image_bytes)
        # Resize for faster transmission (short side 512)
        original_size = pil_image.size
        w, h = original_size
        scale = 512 / min(w, h)
        if scale < 1:
            new_w = int(w * scale)
            new_h = int(h * scale)
            pil_image_resized = pil_image.resize((new_w, new_h), Image.LANCZOS)
            print(f"Resized image for segmentation: {original_size} -> {(new_w, new_h)}")

            # Adjust coordinates and bboxes to the resized space
            if pos_coordinates:
                pos_coords_list = json.loads(pos_coordinates)
                for coord in pos_coords_list:
                    coord['x'] = int(coord['x'] * scale)
                    coord['y'] = int(coord['y'] * scale)
                pos_coordinates = json.dumps(pos_coords_list)
            if neg_coordinates:
                neg_coords_list = json.loads(neg_coordinates)
                for coord in neg_coords_list:
                    coord['x'] = int(coord['x'] * scale)
                    coord['y'] = int(coord['y'] * scale)
                neg_coordinates = json.dumps(neg_coords_list)
            if bboxes_xyxy:
                bboxes_list = json.loads(bboxes_xyxy)
                new_bboxes = [
                    (int(x0 * scale), int(y0 * scale), int(x1 * scale), int(y1 * scale))
                    for (x0, y0, x1, y1) in bboxes_list
                ]
                bboxes_xyxy = json.dumps(new_bboxes)
        else:
            pil_image_resized = pil_image
            scale = 1.0

        # Write the (possibly resized) image to a temp file for upload
        with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
            pil_image_resized.save(temp_in.name, format="WEBP", quality=80)
            temp_in_path = temp_in.name
        # Execute segmentation via the helper Space client, off the event loop
        result_path = await run_in_threadpool(
            sam_client.predict,
            handle_file(temp_in_path),
            pos_coordinates,
            neg_coordinates,
            bboxes_xyxy,
            api_name="/segment"
        )
        os.unlink(temp_in_path)

        if isinstance(result_path, (list, tuple)):
            result_path = result_path[0]
        if not result_path or not os.path.exists(result_path):
            raise RuntimeError("Client returned invalid result path")

        mask_pil = Image.open(result_path)
        if mask_pil.mode != 'L':
            mask_pil = mask_pil.convert('L')
        pil_image = pil_image.convert("RGB")
        if pil_image.size != mask_pil.size:
            mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
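        # Pack the mask into the alpha channel so the frontend receives the
        # cutout and the mask in a single RGBA image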
        r, g, b = pil_image.split()
        res_pil = Image.merge("RGBA", (r, g, b, mask_pil))

        # Compute the mask's bounding box in original-image coordinates
        mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
        mask_bbox = get_mask_bbox(mask_tensor)
        if mask_bbox:
            x_min, y_min, x_max, y_max = mask_bbox
            seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
        else:
            seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}

        buffered = io.BytesIO()
        res_pil.save(buffered, format="PNG")
        image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {
            "error": False,
            "segmentation_image": "data:image/png;base64," + image_base64_res,
            "segmentation_bbox": seg_bbox
        }
    except Exception as e:
        print(f"Error in segmentation: {e}")
        return {"error": str(e)}
# Queue settings: cap concurrency and queue size so the ZeroGPU backend is not
# flooded with simultaneous requests
demo.queue(default_concurrency_limit=10, max_size=20)

# Mount the Gradio app at the root of the FastAPI app
app = gr.mount_gradio_app(app, demo, path="/", root_path="/demo")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)