import json
import os
import random
import time

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from safetensors.torch import save_file
from ulid import ULID

from diffsynth.pipelines.qwen_image import (
    QwenImagePipeline, ModelConfig,
    QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode,
)

from utils import repo_utils, image_utils, prompt_utils


URL_PUBLIC = "https://huggingface.co/spaces/AiSudo/Qwen-Image-to-LoRA/blob/main"
DTYPE = torch.bfloat16
MAX_SEED = np.iinfo(np.int32).max
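

# Offload settings for the LoRA-encoder models: weights stay on disk and are
# only materialized on the GPU (in bfloat16) while a layer computes.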
vram_config_disk_offload = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": "disk",
    "onload_device": "disk",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
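

# Image-to-LoRA pipeline: SigLIP2 and DINOv3 image encoders plus the
# Qwen-Image-i2L style decoder, all loaded with full disk offloading.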
pipe_lora = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(
            download_source="huggingface",
            model_id="DiffSynth-Studio/General-Image-Encoders",
            origin_file_pattern="SigLIP2-G384/model.safetensors",
            **vram_config_disk_offload,
        ),
        ModelConfig(
            download_source="huggingface",
            model_id="DiffSynth-Studio/General-Image-Encoders",
            origin_file_pattern="DINOv3-7B/model.safetensors",
            **vram_config_disk_offload,
        ),
        ModelConfig(
            download_source="huggingface",
            model_id="DiffSynth-Studio/Qwen-Image-i2L",
            origin_file_pattern="Qwen-Image-i2L-Style.safetensors",
            **vram_config_disk_offload,
        ),
    ],
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
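

# Offload settings for the text-to-image models: weights are cached on disk
# but kept resident on the GPU in bfloat16 once onloaded.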
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
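

# Qwen-Image text-to-image pipeline (transformer, text encoder, VAE) used to
# preview the generated LoRA.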
pipe_imagen = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(download_source="huggingface", model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(download_source="huggingface", model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
        ModelConfig(download_source="huggingface", model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(download_source="huggingface", model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
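

# Step 1: encode the uploaded reference images into a LoRA checkpoint.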
@spaces.GPU
def generate_lora(
    input_images,
    progress=gr.Progress(track_tqdm=True),
):
    ulid = str(ULID()).lower()[:12]
    print(f"ulid: {ulid}")

    if not input_images:
        raise gr.Error("Please upload at least one input image.")

    # gr.Gallery delivers (filepath, caption) tuples.
    input_images = [Image.open(filepath).convert("RGB") for filepath, _ in input_images]

    with torch.no_grad():
        embs = QwenImageUnit_Image2LoRAEncode().process(pipe_lora, image2lora_images=input_images)
        lora = QwenImageUnit_Image2LoRADecode().process(pipe_lora, **embs)["lora"]

    lora_name = f"{ulid}.safetensors"
    lora_path = f"loras/{lora_name}"
    os.makedirs("loras", exist_ok=True)  # the output directory may not exist on a fresh checkout
    save_file(lora, lora_path)

    return lora_name, gr.update(interactive=True, value=lora_path), gr.update(interactive=True)
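

# Step 2: preview the LoRA with the Qwen-Image pipeline. The checkpoint is
# plain safetensors, so it can also be applied outside this app in the same
# way, e.g. pipe.load_lora(pipe.dit, "loras/<ulid>.safetensors").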
@spaces.GPU
def generate_image(
    lora_name,
    prompt,
    negative_prompt="blurry ugly bad",
    width=1024,
    height=1024,
    seed=42,
    randomize_seed=True,
    guidance_scale=3.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    lora_path = f"loras/{lora_name}"
    pipe_imagen.clear_lora()
    pipe_imagen.load_lora(pipe_imagen.dit, lora_path)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # DiffSynth pipelines take the seed and CFG scale as call arguments.
    output_image = pipe_imagen(
        prompt=prompt,
        negative_prompt=negative_prompt,
        cfg_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        seed=seed,
    )

    return output_image, seed
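

# Small helper for inlining the static header/footer assets into the UI.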
def read_file(path: str) -> str:
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    return content
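

# Page CSS, applied via gr.Blocks(css=...).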
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
h3 {
    text-align: center;
    display: block;
}
"""
with open("examples/0_examples.json", "r") as file:
    examples = json.load(file)
print(examples)
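

# UI: a gallery-to-LoRA panel on top, an optional image-preview panel below.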
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))

        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Input images",
                    file_types=["image"],
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                    object_fit="cover",
                    height=300,
                )
                lora_button = gr.Button("Generate LoRA", variant="primary")

            with gr.Column():
                lora_name = gr.Textbox(label="Generated LoRA path", lines=2, interactive=False)
                lora_download = gr.DownloadButton(label="Download LoRA", interactive=False)

        with gr.Column(elem_id="imagen-container") as imagen_container:
            gr.Markdown("### Once your LoRA is ready, you can try generating an image with it here.")
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        lines=2,
                        placeholder="Enter your prompt",
                        value="a man in a fishing boat.",
                        container=False,
                    )

                    imagen_button = gr.Button("Generate Image", variant="primary", interactive=False)

                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative prompt",
                            lines=2,
                            container=False,
                            placeholder="Enter your negative prompt",
                            value="blurry ugly bad",
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=768,
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=1024,
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                            )
                            guidance_scale = gr.Slider(
                                label="Guidance scale",
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                value=3.5,
                            )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                with gr.Column():
                    output_image = gr.Image(label="Generated image", show_label=False)

        gr.Examples(examples=examples, inputs=[input_images])
        gr.Markdown(read_file("static/footer.md"))
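
    # Wire up the two steps: the preview button only becomes interactive
    # after a LoRA has been generated.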
    lora_button.click(
        fn=generate_lora,
        inputs=[input_images],
        outputs=[lora_name, lora_download, imagen_button],
    )
    imagen_button.click(
        fn=generate_image,
        inputs=[
            lora_name,
            prompt,
            negative_prompt,
            width,
            height,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )


if __name__ == "__main__":
    demo.launch(mcp_server=True)