import gradio as gr
from gradio_client import Client, handle_file
import spaces
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
from datetime import datetime
import shutil
import cv2
from typing import Tuple
import torch
import numpy as np
from PIL import Image
import base64
import io
import tempfile
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
MODES = [
{"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
{"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
{"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
{"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
{"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
{"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8         # number of pre-rendered view angles per mode
DEFAULT_MODE = 3  # index into MODES shown first (HDRI forest)
DEFAULT_STEP = 3  # view angle shown first
css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper {
padding: 0;
}
.stepper-container {
padding: 0;
align-items: center;
}
.step-button {
flex-direction: row;
}
.step-connector {
transform: none;
}
.step-number {
width: 16px;
height: 16px;
}
.step-label {
position: relative;
bottom: 0;
}
.wrap.center.full {
inset: 0;
height: 100%;
}
.wrap.center.full.translucent {
background: var(--block-background-fill);
}
.meta-text-center {
display: block !important;
position: absolute !important;
top: unset !important;
bottom: 0 !important;
right: 0 !important;
transform: unset !important;
}
/* Previewer */
.previewer-container {
position: relative;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
width: 100%;
height: 722px;
margin: 0 auto;
padding: 20px;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.previewer-container .tips-icon {
position: absolute;
right: 10px;
top: 10px;
z-index: 10;
border-radius: 10px;
color: #fff;
background-color: var(--color-accent);
padding: 3px 6px;
user-select: none;
}
.previewer-container .tips-text {
position: absolute;
right: 10px;
top: 50px;
color: #fff;
background-color: var(--color-accent);
border-radius: 10px;
padding: 6px;
text-align: left;
max-width: 300px;
z-index: 10;
transition: all 0.3s;
opacity: 0%;
user-select: none;
}
.previewer-container .tips-text p {
font-size: 14px;
line-height: 1.2;
}
.tips-icon:hover + .tips-text {
display: block;
opacity: 100%;
}
/* Row 1: Display Modes */
.previewer-container .mode-row {
width: 100%;
display: flex;
gap: 8px;
justify-content: center;
margin-bottom: 20px;
flex-wrap: wrap;
}
.previewer-container .mode-btn {
width: 24px;
height: 24px;
border-radius: 50%;
cursor: pointer;
opacity: 0.5;
transition: all 0.2s;
border: 2px solid #ddd;
object-fit: cover;
}
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active {
opacity: 1;
border-color: var(--color-accent);
transform: scale(1.1);
}
/* Row 2: Display Image */
.previewer-container .display-row {
margin-bottom: 20px;
min-height: 400px;
width: 100%;
flex-grow: 1;
display: flex;
justify-content: center;
align-items: center;
}
.previewer-container .previewer-main-image {
max-width: 100%;
max-height: 100%;
flex-grow: 1;
object-fit: contain;
display: none;
}
.previewer-container .previewer-main-image.visible {
display: block;
}
/* Row 3: Custom HTML Slider */
.previewer-container .slider-row {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
gap: 10px;
padding: 0 10px;
}
.previewer-container input[type=range] {
-webkit-appearance: none;
width: 100%;
max-width: 400px;
background: transparent;
}
.previewer-container input[type=range]::-webkit-slider-runnable-track {
width: 100%;
height: 8px;
cursor: pointer;
background: #ddd;
border-radius: 5px;
}
.previewer-container input[type=range]::-webkit-slider-thumb {
height: 20px;
width: 20px;
border-radius: 50%;
background: var(--color-accent);
cursor: pointer;
-webkit-appearance: none;
margin-top: -6px;
box-shadow: 0 2px 5px rgba(0,0,0,0.2);
transition: transform 0.1s;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover {
transform: scale(1.2);
}
/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) {
padding: 0 !important;
}
.gradio-container:has(.previewer-container) [data-testid="block-label"] {
position: absolute;
top: 0;
left: 0;
}
"""
head = """
"""
empty_html = f"""
"""
def image_to_base64(image):
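    """Encode a PIL image as a base64 JPEG data URI for inline HTML embedding."""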
buffered = io.BytesIO()
image = image.convert("RGB")
image.save(buffered, format="jpeg", quality=85)
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/jpeg;base64,{img_str}"
def start_session(req: gr.Request):
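    """Create a per-session temp directory keyed by the Gradio session hash."""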
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request):
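    """Remove the per-session temp directory when the client disconnects."""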
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
def remove_background(input: Image.Image) -> Image.Image:
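    """Strip the background via the BRIA-RMBG-2.0 Space, round-tripping through a temp PNG."""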
with tempfile.NamedTemporaryFile(suffix='.png') as f:
input = input.convert('RGB')
input.save(f.name)
output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
output = Image.open(output)
return output
def preprocess_image(input: Image.Image) -> Image.Image:
"""
Preprocess the input image.
"""
# if has alpha channel, use it directly; otherwise, remove background
has_alpha = False
if input.mode == 'RGBA':
alpha = np.array(input)[:, :, 3]
if not np.all(alpha == 255):
has_alpha = True
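    # Cap the longer side at 1024 px to bound inference cost.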
max_size = max(input.size)
scale = min(1, 1024 / max_size)
if scale < 1:
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
if has_alpha:
output = input
else:
output = remove_background(input)
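    # Crop to a square around the alpha-tight bounding box.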
output_np = np.array(output)
alpha = output_np[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    size = int(size * 1)  # padding factor around the object (1 = tight crop)
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
output = output.crop(bbox) # type: ignore
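    # Composite RGB over black using the alpha channel.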
output = np.array(output).astype(np.float32) / 255
output = output[:, :, :3] * output[:, :, 3:4]
output = Image.fromarray((output * 255).astype(np.uint8))
return output
def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
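    """Serialize the sampled latents to CPU numpy arrays so they can live in gr.State."""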
shape_slat, tex_slat, res = latents
return {
'shape_slat_feats': shape_slat.feats.cpu().numpy(),
'tex_slat_feats': tex_slat.feats.cpu().numpy(),
'coords': shape_slat.coords.cpu().numpy(),
'res': res,
}
def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
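    """Rebuild the CUDA sparse latents from a packed state dict."""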
shape_slat = SparseTensor(
feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
coords=torch.from_numpy(state['coords']).cuda(),
)
tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
return shape_slat, tex_slat, state['res']
def get_seed(randomize_seed: bool, seed: int) -> int:
"""
Get the random seed.
"""
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU(duration=120)
def image_to_3d(
image: Image.Image,
seed: int,
resolution: str,
ss_guidance_strength: float,
ss_guidance_rescale: float,
ss_sampling_steps: int,
ss_rescale_t: float,
shape_slat_guidance_strength: float,
shape_slat_guidance_rescale: float,
shape_slat_sampling_steps: int,
shape_slat_rescale_t: float,
tex_slat_guidance_strength: float,
tex_slat_guidance_rescale: float,
tex_slat_sampling_steps: int,
tex_slat_rescale_t: float,
req: gr.Request,
progress=gr.Progress(track_tqdm=True),
) -> Tuple[dict, str]:
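    """
    Run the TRELLIS.2 image-to-3D pipeline on a preprocessed image and build
    the HTML previewer, with one pre-rendered snapshot per (mode, view angle).
    """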
# --- Sampling ---
outputs, latents = pipeline.run(
image,
seed=seed,
preprocess_image=False,
sparse_structure_sampler_params={
"steps": ss_sampling_steps,
"guidance_strength": ss_guidance_strength,
"guidance_rescale": ss_guidance_rescale,
"rescale_t": ss_rescale_t,
},
shape_slat_sampler_params={
"steps": shape_slat_sampling_steps,
"guidance_strength": shape_slat_guidance_strength,
"guidance_rescale": shape_slat_guidance_rescale,
"rescale_t": shape_slat_rescale_t,
},
tex_slat_sampler_params={
"steps": tex_slat_sampling_steps,
"guidance_strength": tex_slat_guidance_strength,
"guidance_rescale": tex_slat_guidance_rescale,
"rescale_t": tex_slat_rescale_t,
},
pipeline_type={
"512": "512",
"1024": "1024_cascade",
"1536": "1536_cascade",
}[resolution],
return_latent=True,
)
mesh = outputs[0]
mesh.simplify(16777216) # nvdiffrast limit
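    # Pre-render one snapshot per view angle for every display mode in a single pass.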
images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
state = pack_state(latents)
torch.cuda.empty_cache()
# --- HTML Construction ---
# The Stack of 48 Images
images_html = ""
for m_idx, mode in enumerate(MODES):
for s_idx in range(STEPS):
# ID Naming Convention: view-m{mode}-s{step}
unique_id = f"view-m{m_idx}-s{s_idx}"
            # Logic: only the default mode/step image is visible initially
is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
vis_class = "visible" if is_visible else ""
# Image Source
img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
# Render the Tag
images_html += f"""
"""
# Button Row HTML
btns_html = ""
for idx, mode in enumerate(MODES):
active_class = "active" if idx == DEFAULT_MODE else ""
        # Note: onclick calls the switch function injected via `head` (name assumed)
        btns_html += f"""
            <img class="mode-btn {active_class}" src="{mode['icon_base64']}" title="{mode['name']}" onclick="switchMode({idx})" />
        """
# Assemble the full component
full_html = f"""
💡Tips
● Render Mode - Click on the circular buttons to switch between different render modes.
● View Angle - Drag the slider to change the view angle.
{images_html}
{btns_html}
"""
return state, full_html
@spaces.GPU(duration=120)
def extract_glb(
state: dict,
decimation_target: int,
texture_size: int,
req: gr.Request,
progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
"""
Extract a GLB file from the 3D model.
Args:
state (dict): The state of the generated 3D model.
decimation_target (int): The target face count for decimation.
texture_size (int): The texture resolution.
    Returns:
        Tuple[str, str]: The path to the extracted GLB file, returned twice (for the 3D viewer and the download button).
"""
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
shape_slat, tex_slat, res = unpack_state(state)
mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
mesh.simplify(16777216)
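    # Bake the voxel attribute volume into a decimated, UV-textured GLB.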
glb = o_voxel.postprocess.to_glb(
vertices=mesh.vertices,
faces=mesh.faces,
attr_volume=mesh.attrs,
coords=mesh.coords,
attr_layout=pipeline.pbr_attr_layout,
grid_size=res,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
decimation_target=decimation_target,
texture_size=texture_size,
remesh=True,
remesh_band=1,
remesh_project=0,
use_tqdm=True,
)
now = datetime.now()
timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
os.makedirs(user_dir, exist_ok=True)
glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
glb.export(glb_path, extension_webp=True)
torch.cuda.empty_cache()
return glb_path, glb_path
with gr.Blocks(delete_cache=(600, 600)) as demo:
gr.Markdown("""
## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
* Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
    * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try a different seed and generate again.
""")
with gr.Row():
with gr.Column(scale=1, min_width=360):
image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
generate_btn = gr.Button("Generate")
with gr.Accordion(label="Advanced Settings", open=False):
gr.Markdown("Stage 1: Sparse Structure Generation")
with gr.Row():
ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
gr.Markdown("Stage 2: Shape Generation")
with gr.Row():
shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
gr.Markdown("Stage 3: Material Generation")
with gr.Row():
tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
with gr.Column(scale=10):
with gr.Walkthrough(selected=0) as walkthrough:
with gr.Step("Preview", id=0):
preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
extract_btn = gr.Button("Extract GLB")
with gr.Step("Extract", id=1):
glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
download_btn = gr.DownloadButton(label="Download GLB")
gr.Markdown("*We are actively working on improving the speed of GLB extraction. Currently, it may take half a minute or more and face count is limited.*")
with gr.Column(scale=1, min_width=172):
examples = gr.Examples(
examples=[
f'assets/example_image/{image}'
                    for image in sorted(os.listdir("assets/example_image"))
],
inputs=[image_prompt],
fn=preprocess_image,
outputs=[image_prompt],
run_on_click=True,
examples_per_page=18,
)
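    # Holds the packed latents between generation and GLB extraction.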
output_buf = gr.State()
# Handlers
demo.load(start_session)
demo.unload(end_session)
image_prompt.upload(
preprocess_image,
inputs=[image_prompt],
outputs=[image_prompt],
)
generate_btn.click(
get_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
lambda: gr.Walkthrough(selected=0), outputs=walkthrough
).then(
image_to_3d,
inputs=[
image_prompt, seed, resolution,
ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
],
outputs=[output_buf, preview_output],
)
extract_btn.click(
lambda: gr.Walkthrough(selected=1), outputs=walkthrough
).then(
extract_glb,
inputs=[output_buf, decimation_target, texture_size],
outputs=[glb_output, download_btn],
)
# Launch the Gradio app
if __name__ == "__main__":
os.makedirs(TMP_DIR, exist_ok=True)
    # Construct UI components: embed the mode icons as base64 data URIs for the previewer buttons
    for mode in MODES:
        icon = Image.open(mode['icon'])
        mode['icon_base64'] = image_to_base64(icon)
rmbg_client = Client("briaai/BRIA-RMBG-2.0")
pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
pipeline.rembg_model = None
pipeline.low_vram = False
pipeline.cuda()
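    # Preload the HDR environment maps (OpenEXR) onto the GPU for the shaded previews.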
envmap = {
'forest': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'sunset': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'courtyard': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
}
demo.launch(css=css, head=head)