|
|
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py |
|
|
""" |
|
|
import base64 |
|
|
import gc |
|
|
import json |
|
|
import os |
|
|
import hashlib |
|
|
import random |
|
|
from datetime import datetime |
|
|
from glob import glob |
|
|
|
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import pkg_resources |
|
|
import requests |
|
|
import torch |
|
|
from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler, |
|
|
DPMSolverMultistepScheduler, |
|
|
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, |
|
|
FlowMatchEulerDiscreteScheduler, PNDMScheduler) |
|
|
from omegaconf import OmegaConf |
|
|
from PIL import Image |
|
|
from safetensors import safe_open |
|
|
|
|
|
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio |
|
|
from ..utils.utils import save_videos_grid |
|
|
from ..utils.fm_solvers import FlowDPMSolverMultistepScheduler |
|
|
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
from ..dist import set_multi_gpus_devices |
|
|
|
|
|
gradio_version = pkg_resources.get_distribution("gradio").version |
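# Gradio >= 4 removed the `gr.Component.update(...)` classmethods in favor of returning
# new component instances, so the UI code below branches on the installed major version.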
|
|
gradio_version_is_above_4 = int(gradio_version.split('.')[0]) >= 4
|
|
|
|
|
css = """ |
|
|
.toolbutton { |
|
|
    margin-bottom: 0em;
|
|
max-width: 2.5em; |
|
|
min-width: 2.5em !important; |
|
|
height: 2.5em; |
|
|
} |
|
|
""" |
|
|
|
|
|
ddpm_scheduler_dict = { |
|
|
"Euler": EulerDiscreteScheduler, |
|
|
"Euler A": EulerAncestralDiscreteScheduler, |
|
|
"DPM++": DPMSolverMultistepScheduler, |
|
|
"PNDM": PNDMScheduler, |
|
|
"DDIM": DDIMScheduler, |
|
|
"DDIM_Origin": DDIMScheduler, |
|
|
"DDIM_Cog": CogVideoXDDIMScheduler, |
|
|
} |
|
|
flow_scheduler_dict = { |
|
|
"Flow": FlowMatchEulerDiscreteScheduler, |
|
|
"Flow_Unipc": FlowUniPCMultistepScheduler, |
|
|
"Flow_DPM++": FlowDPMSolverMultistepScheduler, |
|
|
} |
|
|
all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict} |
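# Illustrative note: the sampler dropdown label maps to a diffusers scheduler class,
# which is typically instantiated from the current pipeline's scheduler config, e.g.
#   scheduler_cls = all_cheduler_dict[sampler_dropdown]
#   pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)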
|
|
|
|
|
class Fun_Controller: |
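    """Gradio controller that owns the loaded models, paths, and generation settings.

    Minimal usage sketch (illustrative; the available GPU_memory_mode values and
    model checkpoints depend on the deployment):

        controller = Fun_Controller(
            GPU_memory_mode="model_cpu_offload",
            scheduler_dict=flow_scheduler_dict,
            model_type="Inpaint",
            weight_dtype=torch.bfloat16,
        )
    """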
|
|
def __init__( |
|
|
self, GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", |
|
|
config_path=None, ulysses_degree=1, ring_degree=1, |
|
|
fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False, |
|
|
weight_dtype=None, savedir_sample=None, |
|
|
): |
|
|
|
|
|
self.basedir = os.getcwd() |
|
|
self.config_dir = os.path.join(self.basedir, "config") |
|
|
self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer") |
|
|
self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module") |
|
|
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model") |
|
|
if savedir_sample is None: |
|
|
self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) |
|
|
else: |
|
|
self.savedir_sample = savedir_sample |
|
|
os.makedirs(self.savedir_sample, exist_ok=True) |
|
|
|
|
|
self.GPU_memory_mode = GPU_memory_mode |
|
|
self.model_name = model_name |
|
|
self.diffusion_transformer_dropdown = model_name |
|
|
self.scheduler_dict = scheduler_dict |
|
|
self.model_type = model_type |
|
|
if config_path is not None: |
|
|
self.config_path = os.path.realpath(config_path) |
|
|
self.config = OmegaConf.load(config_path) |
|
|
else: |
|
|
self.config_path = None |
|
|
self.ulysses_degree = ulysses_degree |
|
|
self.ring_degree = ring_degree |
|
|
self.fsdp_dit = fsdp_dit |
|
|
self.fsdp_text_encoder = fsdp_text_encoder |
|
|
self.compile_dit = compile_dit |
|
|
self.weight_dtype = weight_dtype |
|
|
self.device = set_multi_gpus_devices(self.ulysses_degree, self.ring_degree) |
|
|
|
|
|
self.diffusion_transformer_list = [] |
|
|
self.motion_module_list = [] |
|
|
self.personalized_model_list = [] |
|
|
self.config_list = [] |
|
|
|
|
|
|
|
|
self.tokenizer = None |
|
|
self.text_encoder = None |
|
|
self.vae = None |
|
|
self.transformer = None |
|
|
self.transformer_2 = None |
|
|
self.pipeline = None |
|
|
self.base_model_path = "none" |
|
|
self.base_model_2_path = "none" |
|
|
self.lora_model_path = "none" |
|
|
self.lora_model_2_path = "none" |
|
|
|
|
|
self.refresh_config() |
|
|
self.refresh_diffusion_transformer() |
|
|
self.refresh_personalized_model() |
|
|
        if model_name is not None:
|
|
self.update_diffusion_transformer(model_name) |
|
|
|
|
|
def refresh_config(self): |
|
|
config_list = [] |
|
|
for root, dirs, files in os.walk(self.config_dir): |
|
|
for file in files: |
|
|
if file.endswith(('.yaml', '.yml')): |
|
|
full_path = os.path.join(root, file) |
|
|
config_list.append(full_path) |
|
|
self.config_list = config_list |
|
|
|
|
|
def refresh_diffusion_transformer(self): |
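        # Each model checkpoint is expected to live in its own subdirectory under
        # models/Diffusion_Transformer; list those folders for the dropdown.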
|
|
self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/"))) |
|
|
|
|
|
def refresh_personalized_model(self): |
|
|
personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors"))) |
|
|
self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list] |
|
|
|
|
|
def update_model_type(self, model_type): |
|
|
self.model_type = model_type |
|
|
|
|
|
def update_config(self, config_dropdown): |
|
|
self.config_path = config_dropdown |
|
|
self.config = OmegaConf.load(config_dropdown) |
|
|
print(f"Update config: {config_dropdown}") |
|
|
|
|
|
def update_diffusion_transformer(self, diffusion_transformer_dropdown): |
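        # Intentionally a no-op in this base controller; model-specific subclasses
        # are expected to override this and load the selected checkpoint.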
|
|
pass |
|
|
|
|
|
def update_base_model(self, base_model_dropdown, is_checkpoint_2=False): |
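        # Load a personalized full checkpoint (.safetensors) from personalized_model_dir
        # on top of the already-loaded transformer (strict=False tolerates key mismatches).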
|
|
if not is_checkpoint_2: |
|
|
self.base_model_path = base_model_dropdown |
|
|
else: |
|
|
self.base_model_2_path = base_model_dropdown |
|
|
print(f"Update base model: {base_model_dropdown}") |
|
|
if base_model_dropdown == "none": |
|
|
return gr.update() |
|
|
if self.transformer is None and not is_checkpoint_2: |
|
|
gr.Info(f"Please select a pretrained model path.") |
|
|
print(f"Please select a pretrained model path.") |
|
|
return gr.update(value=None) |
|
|
elif self.transformer_2 is None and is_checkpoint_2: |
|
|
gr.Info(f"Please select a pretrained model path.") |
|
|
print(f"Please select a pretrained model path.") |
|
|
return gr.update(value=None) |
|
|
else: |
|
|
base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown) |
|
|
base_model_state_dict = {} |
|
|
with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: |
|
|
for key in f.keys(): |
|
|
base_model_state_dict[key] = f.get_tensor(key) |
|
|
if not is_checkpoint_2: |
|
|
self.transformer.load_state_dict(base_model_state_dict, strict=False) |
|
|
else: |
|
|
self.transformer_2.load_state_dict(base_model_state_dict, strict=False) |
|
|
print("Update base model done") |
|
|
return gr.update() |
|
|
|
|
|
def update_lora_model(self, lora_model_dropdown, is_checkpoint_2=False): |
|
|
print(f"Update lora model: {lora_model_dropdown}") |
|
|
        if lora_model_dropdown == "none":
            if not is_checkpoint_2:
                self.lora_model_path = "none"
            else:
                self.lora_model_2_path = "none"
            return gr.update()
|
|
lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown) |
|
|
if not is_checkpoint_2: |
|
|
self.lora_model_path = lora_model_dropdown |
|
|
else: |
|
|
self.lora_model_2_path = lora_model_dropdown |
|
|
return gr.update() |
|
|
|
|
|
    def clear_cache(self):
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.ipc_collect() |
|
|
|
|
|
def auto_model_clear_cache(self, model): |
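        # Temporarily move the model to CPU, release cached CUDA memory, then move it back.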
|
|
origin_device = model.device |
|
|
model = model.to("cpu") |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.ipc_collect() |
|
|
model = model.to(origin_device) |
|
|
|
|
|
def input_check(self, |
|
|
resize_method, |
|
|
generation_method, |
|
|
start_image, |
|
|
end_image, |
|
|
validation_video, |
|
|
control_video, |
|
|
is_api = False, |
|
|
): |
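        # Validate UI/API inputs against the loaded model. In API mode problems are
        # returned as ("", message); in the Gradio UI they raise gr.Error instead.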
|
|
if self.transformer is None: |
|
|
if is_api: |
|
|
return "", f"Please select a pretrained model path." |
|
|
else: |
|
|
raise gr.Error(f"Please select a pretrained model path.") |
|
|
|
|
|
if control_video is not None and self.model_type == "Inpaint": |
|
|
if is_api: |
|
|
                return "", "If a control video is provided, please set model_type to \"Control\"."
|
|
else: |
|
|
                raise gr.Error("If a control video is provided, please set model_type to \"Control\".")
|
|
|
|
|
if control_video is None and self.model_type == "Control": |
|
|
if is_api: |
|
|
                return "", "If model_type is set to \"Control\", please provide a control video."
|
|
else: |
|
|
                raise gr.Error("If model_type is set to \"Control\", please provide a control video.")
|
|
|
|
|
if resize_method == "Resize according to Reference": |
|
|
if start_image is None and validation_video is None and control_video is None: |
|
|
if is_api: |
|
|
return "", f"Please upload an image when using \"Resize according to Reference\"." |
|
|
else: |
|
|
raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".") |
|
|
|
|
|
if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None: |
|
|
if is_api: |
|
|
                return "", "Please select an image-to-video pretrained model when using image-to-video generation."
|
|
else: |
|
|
                raise gr.Error("Please select an image-to-video pretrained model when using image-to-video generation.")
|
|
|
|
|
if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation": |
|
|
if is_api: |
|
|
                return "", "Please select an image-to-video pretrained model when using long video generation."
|
|
else: |
|
|
                raise gr.Error("Please select an image-to-video pretrained model when using long video generation.")
|
|
|
|
|
if start_image is None and end_image is not None: |
|
|
if is_api: |
|
|
                return "", "If an ending image is provided, please also provide a starting image."
|
|
else: |
|
|
                raise gr.Error("If an ending image is provided, please also provide a starting image.")
|
|
return "", "OK" |
|
|
|
|
|
def get_height_width_from_reference( |
|
|
self, |
|
|
base_resolution, |
|
|
start_image, |
|
|
validation_video, |
|
|
control_video, |
|
|
): |
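        # Scale the 512-based aspect-ratio buckets to base_resolution, pick the bucket
        # closest to the reference's aspect ratio, then round each side to a multiple of
        # 2 * spatial_compression_ratio so the VAE latents get integer spatial sizes.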
|
|
spatial_compression_ratio = self.vae.config.spatial_compression_ratio if hasattr(self.vae.config, "spatial_compression_ratio") else 8 |
|
|
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} |
|
|
if self.model_type == "Inpaint": |
|
|
if validation_video is not None: |
|
|
original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size |
|
|
else: |
|
|
original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size |
|
|
else: |
|
|
original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size |
|
|
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) |
|
|
height_slider, width_slider = [int(x / spatial_compression_ratio / 2) * spatial_compression_ratio * 2 for x in closest_size] |
|
|
return height_slider, width_slider |
|
|
|
|
|
def save_outputs(self, is_image, length_slider, sample, fps): |
|
|
def save_results(): |
|
|
if not os.path.exists(self.savedir_sample): |
|
|
os.makedirs(self.savedir_sample, exist_ok=True) |
|
|
            index = len(os.listdir(self.savedir_sample)) + 1
|
|
prefix = str(index).zfill(8) |
|
|
|
|
|
md5_hash = hashlib.md5(sample.cpu().numpy().tobytes()).hexdigest() |
|
|
|
|
|
if is_image or length_slider == 1: |
|
|
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png") |
|
|
print(f"Saving to {save_sample_path}") |
|
|
image = sample[0, :, 0] |
|
|
image = image.transpose(0, 1).transpose(1, 2) |
|
|
image = (image * 255).numpy().astype(np.uint8) |
|
|
image = Image.fromarray(image) |
|
|
image.save(save_sample_path) |
|
|
|
|
|
else: |
|
|
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4") |
|
|
print(f"Saving to {save_sample_path}") |
|
|
save_videos_grid(sample, save_sample_path, fps=fps) |
|
|
return save_sample_path |
|
|
|
|
|
if self.ulysses_degree * self.ring_degree > 1: |
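            # With sequence parallelism (ulysses/ring degree > 1) every rank holds the
            # same result, so only rank 0 writes the file to avoid duplicate saves.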
|
|
import torch.distributed as dist |
|
|
if dist.get_rank() == 0: |
|
|
save_sample_path = save_results() |
|
|
else: |
|
|
save_sample_path = None |
|
|
else: |
|
|
save_sample_path = save_results() |
|
|
return save_sample_path |
|
|
|
|
|
def generate( |
|
|
self, |
|
|
diffusion_transformer_dropdown, |
|
|
base_model_dropdown, |
|
|
lora_model_dropdown, |
|
|
lora_alpha_slider, |
|
|
prompt_textbox, |
|
|
negative_prompt_textbox, |
|
|
sampler_dropdown, |
|
|
sample_step_slider, |
|
|
resize_method, |
|
|
width_slider, |
|
|
height_slider, |
|
|
base_resolution, |
|
|
generation_method, |
|
|
length_slider, |
|
|
overlap_video_length, |
|
|
partial_video_length, |
|
|
cfg_scale_slider, |
|
|
start_image, |
|
|
end_image, |
|
|
validation_video, |
|
|
validation_video_mask, |
|
|
control_video, |
|
|
denoise_strength, |
|
|
seed_textbox, |
|
|
enable_teacache = None, |
|
|
teacache_threshold = None, |
|
|
num_skip_start_steps = None, |
|
|
teacache_offload = None, |
|
|
cfg_skip_ratio = None, |
|
|
enable_riflex = None, |
|
|
riflex_k = None, |
|
|
is_api = False, |
|
|
): |
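        # Placeholder: model-specific subclasses are expected to implement the actual
        # sampling loop for the arguments above.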
|
|
pass |
|
|
|
|
|
def post_to_host( |
|
|
diffusion_transformer_dropdown, |
|
|
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, |
|
|
prompt_textbox, negative_prompt_textbox, |
|
|
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, |
|
|
base_resolution, generation_method, length_slider, cfg_scale_slider, |
|
|
start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox, |
|
|
ref_image = None, enable_teacache = None, teacache_threshold = None, num_skip_start_steps = None, |
|
|
    teacache_offload = None, cfg_skip_ratio = None, enable_riflex = None, riflex_k = None,
|
|
): |
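    # Base64-encode any local media inputs and forward the request to a remote
    # inference endpoint; EAS_URL and EAS_TOKEN are read from the environment.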
|
|
if start_image is not None: |
|
|
with open(start_image, 'rb') as file: |
|
|
file_content = file.read() |
|
|
start_image_encoded_content = base64.b64encode(file_content) |
|
|
start_image = start_image_encoded_content.decode('utf-8') |
|
|
|
|
|
if end_image is not None: |
|
|
with open(end_image, 'rb') as file: |
|
|
file_content = file.read() |
|
|
end_image_encoded_content = base64.b64encode(file_content) |
|
|
end_image = end_image_encoded_content.decode('utf-8') |
|
|
|
|
|
if validation_video is not None: |
|
|
with open(validation_video, 'rb') as file: |
|
|
file_content = file.read() |
|
|
validation_video_encoded_content = base64.b64encode(file_content) |
|
|
validation_video = validation_video_encoded_content.decode('utf-8') |
|
|
|
|
|
if validation_video_mask is not None: |
|
|
with open(validation_video_mask, 'rb') as file: |
|
|
file_content = file.read() |
|
|
validation_video_mask_encoded_content = base64.b64encode(file_content) |
|
|
validation_video_mask = validation_video_mask_encoded_content.decode('utf-8') |
|
|
|
|
|
if ref_image is not None: |
|
|
with open(ref_image, 'rb') as file: |
|
|
file_content = file.read() |
|
|
ref_image_encoded_content = base64.b64encode(file_content) |
|
|
ref_image = ref_image_encoded_content.decode('utf-8') |
|
|
|
|
|
datas = { |
|
|
"base_model_path": base_model_dropdown, |
|
|
"lora_model_path": lora_model_dropdown, |
|
|
"lora_alpha_slider": lora_alpha_slider, |
|
|
"prompt_textbox": prompt_textbox, |
|
|
"negative_prompt_textbox": negative_prompt_textbox, |
|
|
"sampler_dropdown": sampler_dropdown, |
|
|
"sample_step_slider": sample_step_slider, |
|
|
"resize_method": resize_method, |
|
|
"width_slider": width_slider, |
|
|
"height_slider": height_slider, |
|
|
"base_resolution": base_resolution, |
|
|
"generation_method": generation_method, |
|
|
"length_slider": length_slider, |
|
|
"cfg_scale_slider": cfg_scale_slider, |
|
|
"start_image": start_image, |
|
|
"end_image": end_image, |
|
|
"validation_video": validation_video, |
|
|
"validation_video_mask": validation_video_mask, |
|
|
"denoise_strength": denoise_strength, |
|
|
"seed_textbox": seed_textbox, |
|
|
|
|
|
"ref_image": ref_image, |
|
|
"enable_teacache": enable_teacache, |
|
|
"teacache_threshold": teacache_threshold, |
|
|
"num_skip_start_steps": num_skip_start_steps, |
|
|
"teacache_offload": teacache_offload, |
|
|
"cfg_skip_ratio": cfg_skip_ratio, |
|
|
"enable_riflex": enable_riflex, |
|
|
"riflex_k": riflex_k, |
|
|
} |
|
|
|
|
|
session = requests.session() |
|
|
session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")}) |
|
|
|
|
|
response = session.post(url=f'{os.environ.get("EAS_URL")}/videox_fun/infer_forward', json=datas, timeout=300) |
|
|
|
|
|
outputs = response.json() |
|
|
return outputs |
|
|
|
|
|
|
|
|
class Fun_Controller_Client: |
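    """Client-side controller: instead of running inference locally, it forwards
    requests to a remote service via post_to_host and saves the returned media."""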
|
|
def __init__(self, scheduler_dict, savedir_sample): |
|
|
self.basedir = os.getcwd() |
|
|
if savedir_sample is None: |
|
|
self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) |
|
|
else: |
|
|
self.savedir_sample = savedir_sample |
|
|
os.makedirs(self.savedir_sample, exist_ok=True) |
|
|
|
|
|
self.scheduler_dict = scheduler_dict |
|
|
|
|
|
def generate( |
|
|
self, |
|
|
diffusion_transformer_dropdown, |
|
|
base_model_dropdown, |
|
|
lora_model_dropdown, |
|
|
lora_alpha_slider, |
|
|
prompt_textbox, |
|
|
negative_prompt_textbox, |
|
|
sampler_dropdown, |
|
|
sample_step_slider, |
|
|
resize_method, |
|
|
width_slider, |
|
|
height_slider, |
|
|
base_resolution, |
|
|
generation_method, |
|
|
length_slider, |
|
|
cfg_scale_slider, |
|
|
start_image, |
|
|
end_image, |
|
|
validation_video, |
|
|
validation_video_mask, |
|
|
denoise_strength, |
|
|
seed_textbox, |
|
|
ref_image = None, |
|
|
enable_teacache = None, |
|
|
teacache_threshold = None, |
|
|
num_skip_start_steps = None, |
|
|
teacache_offload = None, |
|
|
cfg_skip_ratio = None, |
|
|
enable_riflex = None, |
|
|
riflex_k = None, |
|
|
): |
|
|
        is_image = generation_method == "Image Generation"
|
|
|
|
|
outputs = post_to_host( |
|
|
diffusion_transformer_dropdown, |
|
|
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, |
|
|
prompt_textbox, negative_prompt_textbox, |
|
|
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, |
|
|
base_resolution, generation_method, length_slider, cfg_scale_slider, |
|
|
start_image, end_image, validation_video, validation_video_mask, denoise_strength, |
|
|
seed_textbox, ref_image = ref_image, enable_teacache = enable_teacache, teacache_threshold = teacache_threshold, |
|
|
num_skip_start_steps = num_skip_start_steps, teacache_offload = teacache_offload, |
|
|
cfg_skip_ratio = cfg_skip_ratio, enable_riflex = enable_riflex, riflex_k = riflex_k, |
|
|
) |
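        # The service returns the generated media as a base64 string; on failure the
        # response carries a "message" field, which is surfaced to the UI instead.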
|
|
|
|
|
try: |
|
|
base64_encoding = outputs["base64_encoding"] |
|
|
        except KeyError:
|
|
return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"] |
|
|
|
|
|
decoded_data = base64.b64decode(base64_encoding) |
|
|
|
|
|
if not os.path.exists(self.savedir_sample): |
|
|
os.makedirs(self.savedir_sample, exist_ok=True) |
|
|
md5_hash = hashlib.md5(decoded_data).hexdigest() |
|
|
|
|
|
        index = len(os.listdir(self.savedir_sample)) + 1
|
|
prefix = str(index).zfill(8) |
|
|
|
|
|
if is_image or length_slider == 1: |
|
|
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png") |
|
|
print(f"Saving to {save_sample_path}") |
|
|
with open(save_sample_path, "wb") as file: |
|
|
file.write(decoded_data) |
|
|
if gradio_version_is_above_4: |
|
|
return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" |
|
|
else: |
|
|
return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" |
|
|
else: |
|
|
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4") |
|
|
print(f"Saving to {save_sample_path}") |
|
|
with open(save_sample_path, "wb") as file: |
|
|
file.write(decoded_data) |
|
|
if gradio_version_is_above_4: |
|
|
return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" |
|
|
else: |
|
|
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" |
|
|
|