import functools
import glob
import json
import math
import os
import types
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
                             scale_lora_layers, unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph

from .fuser import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    init_distributed_environment, initialize_model_parallel,
                    xFuserLongContextAttention)
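
# NOTE: `xFuserLongContextAttention` (used in the processor below) requires the
# distributed / sequence-parallel environment to be set up first. A minimal
# sketch, assuming the local `.fuser` module mirrors xDiT's API (the exact
# argument names here are assumptions, not verified against `.fuser`):
#
#   init_distributed_environment(rank=int(os.environ["RANK"]),
#                                world_size=int(os.environ["WORLD_SIZE"]))
#   initialize_model_parallel(
#       sequence_parallel_degree=int(os.environ["WORLD_SIZE"]))
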
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to the input tensor using the given frequency tensor. The query or key tensor 'x' is
    rotated by the precomputed frequencies 'freqs_cis'. When `use_real` is False, the input is reshaped as complex
    numbers and the frequency tensor is reshaped for broadcasting compatibility; the result is returned as a real
    tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
        freqs_cis (`Union[torch.Tensor, Tuple[torch.Tensor]]`):
            Precomputed frequency tensor for complex exponentials. A (cos, sin) tuple of shape ([S, D], [S, D]) when
            `use_real` is True, otherwise a complex tensor.

    Returns:
        `torch.Tensor`: The query or key tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # The last dimension holds interleaved (real, imag) pairs.
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # The last dimension holds the real half followed by the imaginary half.
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # Complex path: pair up the last dimension as (real, imag), then rotate by
        # multiplying with the complex frequency tensor.
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(1)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)

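
# Example (a minimal sketch; the shapes follow the docstring above, and the
# numbers are illustrative assumptions):
#
#   q = torch.randn(1, 1024, 24, 128)                        # [B, S, H, D]
#   angles = torch.rand(1024, 64)                            # [S, D/2]
#   freqs = torch.polar(torch.ones_like(angles), angles)     # complex [S, D/2]
#   q_rot = apply_rotary_emb_qwen(q, freqs, use_real=False)  # [B, S, H, D]
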
class QwenImageMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the QwenImage model with sequence parallelism across
    multiple GPUs. It applies rotary embeddings on the query and key vectors of the joint text-image streams, but
    does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "QwenImageMultiGPUsAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_hidden_states_mask: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        if encoder_hidden_states is None:
            raise ValueError("QwenImageMultiGPUsAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for the image stream (sample projections).
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for the text stream (context projections).
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape [B, S, H * D] -> [B, S, H, D] for multi-head attention.
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization where configured.
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply rotary embeddings separately to the image and text streams.
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # The fused long-context attention kernel expects half-precision inputs;
        # fall back to bfloat16 for anything not already in fp16/bf16. (`dtype`
        # was undefined in the original; bfloat16 is used as the target here.)
        half_dtypes = (torch.float16, torch.bfloat16)
        dtype = torch.bfloat16

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        # Run sequence-parallel attention over the joint text-image sequence.
        # `joint_strategy='front'` places the text tokens ahead of the image
        # tokens, so the text block occupies the first `seq_txt` positions.
        joint_hidden_states = xFuserLongContextAttention()(
            None,
            half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
            joint_tensor_query=half(txt_query),
            joint_tensor_key=half(txt_key),
            joint_tensor_value=half(txt_value),
            joint_strategy='front',
        )

        # Reshape back: [B, S, H, D] -> [B, S, H * D], and restore the input dtype.
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(img_query.dtype)

        # Split the joint output back into the text and image streams.
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]
        img_attn_output = joint_hidden_states[:, seq_txt:, :]

        # Output projections: linear (and dropout, when present) for the image
        # stream, the added projection for the text stream.
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)

        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output
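
# Usage sketch (hypothetical wiring; assumes a diffusers `Attention` module
# configured with the added KV projections used by QwenImage's double-stream
# blocks, and tensor names here are illustrative):
#
#   processor = QwenImageMultiGPUsAttnProcessor2_0()
#   attn_module.set_processor(processor)
#   img_out, txt_out = attn_module(
#       hidden_states=img_tokens,              # [B, S_img, C]
#       encoder_hidden_states=txt_tokens,      # [B, S_txt, C]
#       image_rotary_emb=(img_freqs, txt_freqs),
#   )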