JeffreyXiang committed · Commit 388d03f · 1 Parent(s): 917a889
.gitignore CHANGED
@@ -19,7 +19,6 @@ lib64/
  parts/
  sdist/
  var/
- wheels/
  share/python-wheels/
  *.egg-info/
  .installed.cfg
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: TRELLIS.2
  emoji: 🏢
- colorFrom: indigo
- colorTo: blue
+ colorFrom: purple
+ colorTo: red
  sdk: gradio
  sdk_version: 6.1.0
  app_file: app.py
requirements.txt CHANGED
@@ -13,8 +13,13 @@ trimesh==4.10.1
  transformers==4.46.3
  git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
  https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
- https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
- https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
- https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrast-0.3.5-cp310-cp310-linux_x86_64?download=true
- https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrast-0.3.5-cp310-cp310-linux_x86_64.whl
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl
+ # https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
+ # https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
+ # https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl?download=true
+ # https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrast-0.3.5-cp310-cp310-linux_x86_64?download=true
+ # https://huggingface.co/spaces/JeffreyXiang/TRELLIS.2/resolve/main/wheels/nvdiffrec_render-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
trellis2/modules/image_feature_extractor.py ADDED
@@ -0,0 +1,118 @@
+ from typing import *
+ import torch
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from transformers import DINOv3ViTModel
+ import numpy as np
+ from PIL import Image
+
+
+ class DinoV2FeatureExtractor:
+     """
+     Feature extractor for DINOv2 models.
+     """
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+         self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
+         self.model.eval()
+         self.transform = transforms.Compose([
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+
+     def to(self, device):
+         self.model.to(device)
+
+     def cuda(self):
+         self.model.cuda()
+
+     def cpu(self):
+         self.model.cpu()
+
+     @torch.no_grad()
+     def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+         """
+         Extract features from the image.
+
+         Args:
+             image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
+
+         Returns:
+             A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
+         """
+         if isinstance(image, torch.Tensor):
+             assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+         elif isinstance(image, list):
+             assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+             image = [i.resize((518, 518), Image.LANCZOS) for i in image]
+             image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+             image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+             image = torch.stack(image).cuda()
+         else:
+             raise ValueError(f"Unsupported type of image: {type(image)}")
+
+         image = self.transform(image).cuda()
+         features = self.model(image, is_training=True)['x_prenorm']
+         patchtokens = F.layer_norm(features, features.shape[-1:])
+         return patchtokens
+
+
+ class DinoV3FeatureExtractor:
+     """
+     Feature extractor for DINOv3 models.
+     """
+     def __init__(self, model_name: str, image_size=512):
+         self.model_name = model_name
+         self.model = DINOv3ViTModel.from_pretrained(model_name)
+         self.model.eval()
+         self.image_size = image_size
+         self.transform = transforms.Compose([
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+
+     def to(self, device):
+         self.model.to(device)
+
+     def cuda(self):
+         self.model.cuda()
+
+     def cpu(self):
+         self.model.cpu()
+
+     def extract_features(self, image: torch.Tensor) -> torch.Tensor:
+         image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
+         hidden_states = self.model.embeddings(image, bool_masked_pos=None)
+         position_embeddings = self.model.rope_embeddings(image)
+
+         for i, layer_module in enumerate(self.model.layer):
+             hidden_states = layer_module(
+                 hidden_states,
+                 position_embeddings=position_embeddings,
+             )
+
+         return F.layer_norm(hidden_states, hidden_states.shape[-1:])
+
+     @torch.no_grad()
+     def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
+         """
+         Extract features from the image.
+
+         Args:
+             image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
+
+         Returns:
+             A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
+         """
+         if isinstance(image, torch.Tensor):
+             assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+         elif isinstance(image, list):
+             assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+             image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
+             image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+             image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+             image = torch.stack(image).cuda()
+         else:
+             raise ValueError(f"Unsupported type of image: {type(image)}")
+
+         image = self.transform(image).cuda()
+         features = self.extract_features(image)
+         return features
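For orientation, a minimal usage sketch of the new module follows. It is a sketch, not part of the commit: it assumes a CUDA device is available (both extractors move tensors to CUDA internally), and the DINOv3 checkpoint name is an assumed example; the actual model name comes from the pipeline's pretrained config.

    # Hypothetical usage sketch; the checkpoint name below is an assumed
    # example, and a CUDA device is required.
    from PIL import Image
    from trellis2.modules.image_feature_extractor import DinoV3FeatureExtractor

    extractor = DinoV3FeatureExtractor('facebook/dinov3-vitl16-pretrain-lvd1689m', image_size=512)
    extractor.cuda()                      # move the backbone to the GPU
    images = [Image.open('example.png')]  # PIL inputs are resized to image_size internally
    tokens = extractor(images)            # -> (B, N, D) patch-token features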
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -5,8 +5,8 @@ import numpy as np
  from PIL import Image
  from .base import Pipeline
  from . import samplers, rembg
- from .. import trainers
- from ..modules import sparse as sp
+ from ..modules.sparse import SparseTensor
+ from ..modules import image_feature_extractor
  from ..representations import Mesh, MeshWithVoxel


@@ -24,7 +24,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
          tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
          shape_slat_normalization (dict): The normalization parameters for the structured latent.
          tex_slat_normalization (dict): The normalization parameters for the texture latent.
-         image_cond_model (trainers.Trainer): The image conditioning model.
+         image_cond_model (Callable): The image conditioning model.
          rembg_model (Callable): The model for removing background.
          low_vram (bool): Whether to use low-VRAM mode.
      """
@@ -92,7 +92,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
          new_pipeline.shape_slat_normalization = args['shape_slat_normalization']
          new_pipeline.tex_slat_normalization = args['tex_slat_normalization']

-         new_pipeline.image_cond_model = getattr(trainers, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
+         new_pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
          new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])

          new_pipeline.low_vram = args.get('low_vram', True)
@@ -230,7 +230,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
          flow_model,
          coords: torch.Tensor,
          sampler_params: dict = {},
-     ) -> sp.SparseTensor:
+     ) -> SparseTensor:
          """
          Sample structured latent with the given conditioning.

@@ -240,7 +240,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
              sampler_params (dict): Additional parameters for the sampler.
          """
          # Sample structured latent
-         noise = sp.SparseTensor(
+         noise = SparseTensor(
              feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
              coords=coords,
          )
@@ -275,7 +275,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
          coords: torch.Tensor,
          sampler_params: dict = {},
          max_num_tokens: int = 49152,
-     ) -> sp.SparseTensor:
+     ) -> SparseTensor:
          """
          Sample structured latent with the given conditioning.

@@ -285,7 +285,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
              sampler_params (dict): Additional parameters for the sampler.
          """
          # LR
-         noise = sp.SparseTensor(
+         noise = SparseTensor(
              feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device),
              coords=coords,
          )
@@ -329,7 +329,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
              hr_resolution -= 128

          # Sample structured latent
-         noise = sp.SparseTensor(
+         noise = SparseTensor(
              feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
              coords=coords,
          )
@@ -355,19 +355,19 @@ class Trellis2ImageTo3DPipeline(Pipeline):

      def decode_shape_slat(
          self,
-         slat: sp.SparseTensor,
+         slat: SparseTensor,
          resolution: int,
-     ) -> Tuple[List[Mesh], List[sp.SparseTensor]]:
+     ) -> Tuple[List[Mesh], List[SparseTensor]]:
          """
          Decode the structured latent.

          Args:
-             slat (sp.SparseTensor): The structured latent.
+             slat (SparseTensor): The structured latent.
              formats (List[str]): The formats to decode the structured latent to.

          Returns:
              List[Mesh]: The decoded meshes.
-             List[sp.SparseTensor]: The decoded substructures.
+             List[SparseTensor]: The decoded substructures.
          """
          self.models['shape_slat_decoder'].set_resolution(resolution)
          if self.low_vram:
@@ -383,15 +383,15 @@ class Trellis2ImageTo3DPipeline(Pipeline):
          self,
          cond: dict,
          flow_model,
-         shape_slat: sp.SparseTensor,
+         shape_slat: SparseTensor,
          sampler_params: dict = {},
-     ) -> sp.SparseTensor:
+     ) -> SparseTensor:
          """
          Sample structured latent with the given conditioning.

          Args:
              cond (dict): The conditioning information.
-             shape_slat (sp.SparseTensor): The structured latent for shape
+             shape_slat (SparseTensor): The structured latent for shape
              sampler_params (dict): Additional parameters for the sampler.
          """
          # Sample structured latent
@@ -424,18 +424,18 @@ class Trellis2ImageTo3DPipeline(Pipeline):

      def decode_tex_slat(
          self,
-         slat: sp.SparseTensor,
-         subs: List[sp.SparseTensor],
-     ) -> sp.SparseTensor:
+         slat: SparseTensor,
+         subs: List[SparseTensor],
+     ) -> SparseTensor:
          """
          Decode the structured latent.

          Args:
-             slat (sp.SparseTensor): The structured latent.
+             slat (SparseTensor): The structured latent.
              formats (List[str]): The formats to decode the structured latent to.

          Returns:
-             List[sp.SparseTensor]: The decoded texture voxels
+             List[SparseTensor]: The decoded texture voxels
          """
          if self.low_vram:
              self.models['tex_slat_decoder'].to(self.device)
@@ -447,16 +447,16 @@ class Trellis2ImageTo3DPipeline(Pipeline):
      @torch.no_grad()
      def decode_latent(
          self,
-         shape_slat: sp.SparseTensor,
-         tex_slat: sp.SparseTensor,
+         shape_slat: SparseTensor,
+         tex_slat: SparseTensor,
          resolution: int,
      ) -> List[MeshWithVoxel]:
          """
          Decode the latent codes.

          Args:
-             shape_slat (sp.SparseTensor): The structured latent for shape.
-             tex_slat (sp.SparseTensor): The structured latent for texture.
+             shape_slat (SparseTensor): The structured latent for shape.
+             tex_slat (SparseTensor): The structured latent for texture.
              resolution (int): The resolution of the output.
          """
          meshes, subs = self.decode_shape_slat(shape_slat, resolution)
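To make the from_pretrained change above concrete, here is a sketch of the config-driven lookup it switches to: the conditioning model class is now resolved by name from the new image_feature_extractor module instead of from trainers. The config fragment is hypothetical; the real values live in the pipeline's pretrained config.

    # Hypothetical config fragment illustrating the getattr-based resolution.
    from trellis2.modules import image_feature_extractor

    args = {'image_cond_model': {
        'name': 'DinoV3FeatureExtractor',   # class defined in the new module
        'args': {'model_name': 'facebook/dinov3-vitl16-pretrain-lvd1689m'},  # assumed example
    }}
    cls = getattr(image_feature_extractor, args['image_cond_model']['name'])
    image_cond_model = cls(**args['image_cond_model']['args'])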