import gradio as gr
from gradio_client import Client, handle_file
import spaces
import os

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

from datetime import datetime
import shutil
import cv2
from typing import *
import torch
import numpy as np
from PIL import Image
import base64
import io
import tempfile
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8          # view angles rendered per mode
DEFAULT_MODE = 3   # initially visible render mode (HDRI forest)
DEFAULT_STEP = 3   # initially visible view angle

css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper { padding: 0; }
.stepper-container { padding: 0; align-items: center; }
.step-button { flex-direction: row; }
.step-connector { transform: none; }
.step-number { width: 16px; height: 16px; }
.step-label { position: relative; bottom: 0; }
.wrap.center.full { inset: 0; height: 100%; }
.wrap.center.full.translucent { background: var(--block-background-fill); }
.meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }

/* Previewer */
.previewer-container { position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; width: 100%; height: 722px; margin: 0 auto; padding: 20px; display: flex; flex-direction: column; align-items: center; justify-content: center; }
.previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
.previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
.previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
.tips-icon:hover + .tips-text { display: block; opacity: 100%; }

/* Row 1: Display Modes */
.previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
.previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }

/* Row 2: Display Image */
.previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
.previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
.previewer-container .previewer-main-image.visible { display: block; }

/* Row 3: Custom HTML Slider */
.previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
.previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
.previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
.previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
.previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }

/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
.gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
"""

# Script injected into the page <head>. The mode buttons and the angle slider
# generated in image_to_3d call switchMode()/setStep() from here; any minimal
# implementation that toggles the `visible`/`active` classes on the baked
# elements is sufficient.
head = """
<script>
// Previewer state; the initial values mirror DEFAULT_MODE / DEFAULT_STEP.
let currentMode = 3;
let currentStep = 3;

function showView(mode, step) {
    // Hide every baked frame, then reveal the one matching (mode, step).
    document.querySelectorAll('.previewer-main-image').forEach(el => el.classList.remove('visible'));
    const img = document.getElementById(`view-m${mode}-s${step}`);
    if (img) img.classList.add('visible');
}

function switchMode(mode) {
    currentMode = mode;
    document.querySelectorAll('.mode-btn').forEach((el, idx) => el.classList.toggle('active', idx === mode));
    showView(currentMode, currentStep);
}

function setStep(step) {
    currentStep = parseInt(step, 10);
    showView(currentMode, currentStep);
}
</script>
"""

# Placeholder previewer shown before the first generation.
empty_html = """
<div class="previewer-container"></div>
"""
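# The previewer is plain HTML/CSS/JS rather than a Gradio component: `css`
# styles its rows (mode buttons, main image, angle slider), `head` carries the
# client-side script, and `empty_html` is the initial placeholder.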
""" def image_to_base64(image): buffered = io.BytesIO() image = image.convert("RGB") image.save(buffered, format="jpeg", quality=85) img_str = base64.b64encode(buffered.getvalue()).decode() return f"data:image/jpeg;base64,{img_str}" def start_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) os.makedirs(user_dir, exist_ok=True) def end_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) shutil.rmtree(user_dir) def remove_background(input: Image.Image) -> Image.Image: with tempfile.NamedTemporaryFile(suffix='.png') as f: input = input.convert('RGB') input.save(f.name) output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0] output = Image.open(output) return output def preprocess_image(input: Image.Image) -> Image.Image: """ Preprocess the input image. """ # if has alpha channel, use it directly; otherwise, remove background has_alpha = False if input.mode == 'RGBA': alpha = np.array(input)[:, :, 3] if not np.all(alpha == 255): has_alpha = True max_size = max(input.size) scale = min(1, 1024 / max_size) if scale < 1: input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) if has_alpha: output = input else: output = remove_background(input) output_np = np.array(output) alpha = output_np[:, :, 3] bbox = np.argwhere(alpha > 0.8 * 255) bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) size = int(size * 1) bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 output = output.crop(bbox) # type: ignore output = np.array(output).astype(np.float32) / 255 output = output[:, :, :3] * output[:, :, 3:4] output = Image.fromarray((output * 255).astype(np.uint8)) return output def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict: shape_slat, tex_slat, res = latents return { 'shape_slat_feats': shape_slat.feats.cpu().numpy(), 'tex_slat_feats': tex_slat.feats.cpu().numpy(), 'coords': shape_slat.coords.cpu().numpy(), 'res': res, } def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]: shape_slat = SparseTensor( feats=torch.from_numpy(state['shape_slat_feats']).cuda(), coords=torch.from_numpy(state['coords']).cuda(), ) tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda()) return shape_slat, tex_slat, state['res'] def get_seed(randomize_seed: bool, seed: int) -> int: """ Get the random seed. 
""" return np.random.randint(0, MAX_SEED) if randomize_seed else seed @spaces.GPU(duration=120) def image_to_3d( image: Image.Image, seed: int, resolution: str, ss_guidance_strength: float, ss_guidance_rescale: float, ss_sampling_steps: int, ss_rescale_t: float, shape_slat_guidance_strength: float, shape_slat_guidance_rescale: float, shape_slat_sampling_steps: int, shape_slat_rescale_t: float, tex_slat_guidance_strength: float, tex_slat_guidance_rescale: float, tex_slat_sampling_steps: int, tex_slat_rescale_t: float, req: gr.Request, progress=gr.Progress(track_tqdm=True), ) -> str: # --- Sampling --- outputs, latents = pipeline.run( image, seed=seed, preprocess_image=False, sparse_structure_sampler_params={ "steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength, "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t, }, shape_slat_sampler_params={ "steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength, "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t, }, tex_slat_sampler_params={ "steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength, "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t, }, pipeline_type={ "512": "512", "1024": "1024_cascade", "1536": "1536_cascade", }[resolution], return_latent=True, ) mesh = outputs[0] mesh.simplify(16777216) # nvdiffrast limit images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap) state = pack_state(latents) torch.cuda.empty_cache() # --- HTML Construction --- # The Stack of 48 Images images_html = "" for m_idx, mode in enumerate(MODES): for s_idx in range(STEPS): # ID Naming Convention: view-m{mode}-s{step} unique_id = f"view-m{m_idx}-s{s_idx}" # Logic: Only Mode 0, Step 0 is visible initially is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP) vis_class = "visible" if is_visible else "" # Image Source img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx])) # Render the Tag images_html += f""" """ # Button Row HTML btns_html = "" for idx, mode in enumerate(MODES): active_class = "active" if idx == DEFAULT_MODE else "" # Note: onclick calls the JS function defined in Head btns_html += f""" """ # Assemble the full component full_html = f"""
💡Tips

Render Mode - Click on the circular buttons to switch between different render modes.

View Angle - Drag the slider to change the view angle.

{images_html}
{btns_html}
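# extract_glb re-decodes the latents stashed in gr.State into a mesh and
# converts it into a textured GLB via o_voxel: remeshing, decimation to the
# requested face count, and baking the PBR attribute volume into textures
# (exported with WebP compression).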
""" return state, full_html @spaces.GPU(duration=120) def extract_glb( state: dict, decimation_target: int, texture_size: int, req: gr.Request, progress=gr.Progress(track_tqdm=True), ) -> Tuple[str, str]: """ Extract a GLB file from the 3D model. Args: state (dict): The state of the generated 3D model. decimation_target (int): The target face count for decimation. texture_size (int): The texture resolution. Returns: str: The path to the extracted GLB file. """ user_dir = os.path.join(TMP_DIR, str(req.session_hash)) shape_slat, tex_slat, res = unpack_state(state) mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0] mesh.simplify(16777216) glb = o_voxel.postprocess.to_glb( vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs, coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout, grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], decimation_target=decimation_target, texture_size=texture_size, remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True, ) now = datetime.now() timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" os.makedirs(user_dir, exist_ok=True) glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb') glb.export(glb_path, extension_webp=True) torch.cuda.empty_cache() return glb_path, glb_path with gr.Blocks(delete_cache=(600, 600)) as demo: gr.Markdown(""" ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset. * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time. """) with gr.Row(): with gr.Column(scale=1, min_width=360): image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400) resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024") seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000) texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024) generate_btn = gr.Button("Generate") with gr.Accordion(label="Advanced Settings", open=False): gr.Markdown("Stage 1: Sparse Structure Generation") with gr.Row(): ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01) ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1) gr.Markdown("Stage 2: Shape Generation") with gr.Row(): shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01) shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) gr.Markdown("Stage 3: Material Generation") with gr.Row(): tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", 
        with gr.Column(scale=10):
            with gr.Walkthrough(selected=0) as walkthrough:
                with gr.Step("Preview", id=0):
                    preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
                    extract_btn = gr.Button("Extract GLB")
                with gr.Step("Extract", id=1):
                    glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
                    download_btn = gr.DownloadButton(label="Download GLB")
                    gr.Markdown("*We are actively working on improving the speed of GLB extraction. Currently, it may take half a minute or more, and the face count is limited.*")

        with gr.Column(scale=1, min_width=172):
            examples = gr.Examples(
                examples=[f'assets/example_image/{image}' for image in os.listdir("assets/example_image")],
                inputs=[image_prompt],
                fn=preprocess_image,
                outputs=[image_prompt],
                run_on_click=True,
                examples_per_page=18,
            )

    output_buf = gr.State()

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        lambda: gr.Walkthrough(selected=0),
        outputs=walkthrough,
    ).then(
        image_to_3d,
        inputs=[
            image_prompt, seed, resolution,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
        ],
        outputs=[output_buf, preview_output],
    )

    extract_btn.click(
        lambda: gr.Walkthrough(selected=1),
        outputs=walkthrough,
    ).then(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output, download_btn],
    )


# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # Construct UI components: pre-encode the mode button icons as data URIs.
    for i in range(len(MODES)):
        icon = Image.open(MODES[i]['icon'])
        MODES[i]['icon_base64'] = image_to_base64(icon)

    # Remote background-removal Space used by preprocess_image.
    rmbg_client = Client("briaai/BRIA-RMBG-2.0")

    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    pipeline.rembg_model = None  # background removal is handled by rmbg_client
    pipeline.low_vram = False
    pipeline.cuda()

    # HDRI environment maps used by the shaded preview modes.
    envmap = {
        name: EnvMap(torch.tensor(
            cv2.cvtColor(cv2.imread(f'assets/hdri/{name}.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
            dtype=torch.float32, device='cuda',
        ))
        for name in ['forest', 'sunset', 'courtyard']
    }

    demo.launch(css=css, head=head)