add type hint and use keyword arguments
ultranity committed Jan 16, 2024
1 parent 4001c86 commit f899f92
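
For context, this is the pattern the commit applies throughout UNet2DConditionModel.__init__ and its private helpers: long positional argument lists become a single positional argument plus explicit keyword arguments, and the helper signatures gain type hints. A minimal sketch of the idea follows; it is a toy stand-in, not code from the repository, and the keyword-only marker is only there to illustrate the call style (the real helpers accept keywords without enforcing them).

from typing import Optional, Tuple


def set_time_proj(
    time_embedding_type: str,
    *,  # keyword-only from here on, purely to illustrate the call style
    block_out_channels: Tuple[int, ...],
    flip_sin_to_cos: bool,
    freq_shift: float,
    time_embedding_dim: Optional[int] = None,
) -> Tuple[int, int]:
    # Toy stand-in for UNet2DConditionModel._set_time_proj ("positional" branch).
    time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
    timestep_input_dim = block_out_channels[0]
    return time_embed_dim, timestep_input_dim


# Old style (positional, per the removed lines in the diff below):
#     self._set_time_proj(flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim)
# New style: the meaning of each value is visible at the call site and checkable by a type checker.
time_embed_dim, timestep_input_dim = set_time_proj(
    "positional",
    block_out_channels=(320, 640, 1280, 1280),
    flip_sin_to_cos=True,
    freq_shift=0,
)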
Showing 1 changed file with 90 additions and 62 deletions.
152 changes: 90 additions & 62 deletions src/diffusers/models/unet_2d_condition.py
@@ -238,16 +238,16 @@ def __init__(

         # Check inputs
         self._check_config(
-            down_block_types,
-            up_block_types,
-            only_cross_attention,
-            block_out_channels,
-            layers_per_block,
-            cross_attention_dim,
-            transformer_layers_per_block,
-            reverse_transformer_layers_per_block,
-            attention_head_dim,
-            num_attention_heads,
+            down_block_types=down_block_types,
+            up_block_types=up_block_types,
+            only_cross_attention=only_cross_attention,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            cross_attention_dim=cross_attention_dim,
+            transformer_layers_per_block=transformer_layers_per_block,
+            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+            attention_head_dim=attention_head_dim,
+            num_attention_heads=num_attention_heads,
         )

         # input
@@ -258,7 +258,11 @@ def __init__(

         # time
         time_embed_dim, timestep_input_dim = self._set_time_proj(
-            flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim
+            time_embedding_type,
+            block_out_channels=block_out_channels,
+            flip_sin_to_cos=flip_sin_to_cos,
+            freq_shift=freq_shift,
+            time_embedding_dim=time_embedding_dim,
         )

         self.time_embedding = TimestepEmbedding(
@@ -269,28 +273,32 @@ def __init__(
             cond_proj_dim=time_cond_proj_dim,
         )

-        self._set_encoder_hid_proj(cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type)
+        self._set_encoder_hid_proj(
+            encoder_hid_dim_type,
+            cross_attention_dim=cross_attention_dim,
+            encoder_hid_dim=encoder_hid_dim,
+        )

         # class embedding
         self._set_class_embedding(
-            act_fn,
             class_embed_type,
-            num_class_embeds,
-            projection_class_embeddings_input_dim,
-            time_embed_dim,
-            timestep_input_dim,
+            act_fn=act_fn,
+            num_class_embeds=num_class_embeds,
+            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+            time_embed_dim=time_embed_dim,
+            timestep_input_dim=timestep_input_dim,
         )

         self._set_add_embedding(
-            flip_sin_to_cos,
-            freq_shift,
-            cross_attention_dim,
-            encoder_hid_dim,
             addition_embed_type,
-            addition_time_embed_dim,
-            projection_class_embeddings_input_dim,
-            addition_embed_type_num_heads,
-            time_embed_dim,
+            addition_embed_type_num_heads=addition_embed_type_num_heads,
+            addition_time_embed_dim=addition_time_embed_dim,
+            cross_attention_dim=cross_attention_dim,
+            encoder_hid_dim=encoder_hid_dim,
+            flip_sin_to_cos=flip_sin_to_cos,
+            freq_shift=freq_shift,
+            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+            time_embed_dim=time_embed_dim,
         )

         if time_embedding_act_fn is None:
@@ -468,20 +476,20 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

-        self._set_pos_net_if_use_gligen(cross_attention_dim, attention_type)
+        self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)

     def _check_config(
         self,
-        down_block_types,
-        up_block_types,
-        only_cross_attention,
-        block_out_channels,
-        layers_per_block,
-        cross_attention_dim,
-        transformer_layers_per_block,
-        reverse_transformer_layers_per_block,
-        attention_head_dim,
-        num_attention_heads,
+        down_block_types: Tuple[str],
+        up_block_types: Tuple[str],
+        only_cross_attention: Union[bool, Tuple[bool]],
+        block_out_channels: Tuple[int],
+        layers_per_block: Union[int, Tuple[int]],
+        cross_attention_dim: Union[int, Tuple[int]],
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
+        reverse_transformer_layers_per_block: bool,
+        attention_head_dim: int,
+        num_attention_heads: Optional[Union[int, Tuple[int]]],
     ):
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -522,7 +530,14 @@ def _check_config(
             if isinstance(layer_number_per_block, list):
                 raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

-    def _set_time_proj(self, flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim):
+    def _set_time_proj(
+        self,
+        time_embedding_type: str,
+        block_out_channels: int,
+        flip_sin_to_cos: bool,
+        freq_shift: float,
+        time_embedding_dim: int,
+    ) -> Tuple[int, int]:
         if time_embedding_type == "fourier":
             time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
@@ -543,7 +558,12 @@ def _set_time_proj(self, flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim):

         return time_embed_dim, timestep_input_dim

-    def _set_encoder_hid_proj(self, cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type):
+    def _set_encoder_hid_proj(
+        self,
+        encoder_hid_dim_type: Optional[str],
+        cross_attention_dim: Union[int, Tuple[int]],
+        encoder_hid_dim: Optional[int],
+    ):
         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
             self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
@@ -580,12 +600,12 @@ def _set_encoder_hid_proj(self, cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type):

     def _set_class_embedding(
         self,
-        act_fn,
-        class_embed_type,
-        num_class_embeds,
-        projection_class_embeddings_input_dim,
-        time_embed_dim,
-        timestep_input_dim,
+        class_embed_type: Optional[str],
+        act_fn: str,
+        num_class_embeds: Optional[int],
+        projection_class_embeddings_input_dim: Optional[int],
+        time_embed_dim: int,
+        timestep_input_dim: int,
     ):
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -617,15 +637,15 @@ def _set_class_embedding(

     def _set_add_embedding(
         self,
-        flip_sin_to_cos,
-        freq_shift,
-        cross_attention_dim,
-        encoder_hid_dim,
-        addition_embed_type,
-        addition_time_embed_dim,
-        projection_class_embeddings_input_dim,
-        addition_embed_type_num_heads,
-        time_embed_dim,
+        addition_embed_type: str,
+        addition_embed_type_num_heads: int,
+        addition_time_embed_dim: Optional[int],
+        flip_sin_to_cos: bool,
+        freq_shift: float,
+        cross_attention_dim: Optional[int],
+        encoder_hid_dim: Optional[int],
+        projection_class_embeddings_input_dim: Optional[int],
+        time_embed_dim: int,
     ):
         if addition_embed_type == "text":
             if encoder_hid_dim is not None:
@@ -655,7 +675,7 @@ def _set_add_embedding(
         elif addition_embed_type is not None:
             raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

-    def _set_pos_net_if_use_gligen(self, cross_attention_dim, attention_type):
+    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
         if attention_type in ["gated", "gated-text-image"]:
             positive_len = 768
             if isinstance(cross_attention_dim, int):
@@ -889,7 +909,9 @@ def unload_lora(self):
             if hasattr(module, "set_lora_layer"):
                 module.set_lora_layer(None)

-    def get_time_embed(self, sample, timestep):
+    def get_time_embed(
+        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+    ) -> Optional[torch.Tensor]:
         timesteps = timestep
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
@@ -913,7 +935,7 @@ def get_time_embed(self, sample, timestep):
         t_emb = t_emb.to(dtype=sample.dtype)
         return t_emb

-    def get_class_embed(self, sample, class_labels):
+    def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
         class_emb = None
         if self.class_embedding is not None:
             if class_labels is None:
@@ -929,7 +951,9 @@ def get_class_embed(self, sample, class_labels):
             class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
         return class_emb

-    def get_aug_embed(self, encoder_hidden_states, added_cond_kwargs, emb):
+    def get_aug_embed(
+        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
+    ) -> Optional[torch.Tensor]:
         aug_emb = None
         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
@@ -979,7 +1003,7 @@ def get_aug_embed(self, encoder_hidden_states, added_cond_kwargs, emb):
             aug_emb = self.add_embedding(image_embs, hint)
         return aug_emb

-    def process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs):
+    def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
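
The two helpers above consume added_cond_kwargs, whose expected keys depend on the model configuration; for example, SDXL-style checkpoints (addition_embed_type == "text_time") require "text_embeds" and "time_ids". A hedged sketch of how a caller typically builds that dictionary; the batch size and image sizes below are illustrative assumptions, not values taken from this commit.

import torch

added_cond_kwargs = {
    # Pooled text embedding, shape (batch, pooled_dim); 1280 is typical for SDXL.
    "text_embeds": torch.randn(2, 1280),
    # (original_h, original_w, crop_top, crop_left, target_h, target_w) per sample.
    "time_ids": torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]] * 2),
}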
@@ -1121,18 +1145,20 @@ def forward(
             sample = 2 * sample - 1.0

         # 1. time
-        t_emb = self.get_time_embed(sample, timestep)
+        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
         emb = self.time_embedding(t_emb, timestep_cond)
         aug_emb = None

-        class_emb = self.get_class_embed(sample, class_labels)
+        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
         if class_emb is not None:
             if self.config.class_embeddings_concat:
                 emb = torch.cat([emb, class_emb], dim=-1)
             else:
                 emb = emb + class_emb

-        aug_emb = self.get_aug_embed(encoder_hidden_states, added_cond_kwargs, emb)
+        aug_emb = self.get_aug_embed(
+            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+        )
         if self.config.addition_embed_type == "image_hint":
             aug_emb, hint = aug_emb
             sample = torch.cat([sample, hint], dim=1)
@@ -1141,7 +1167,9 @@ def forward(
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)

-        encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
+        encoder_hidden_states = self.process_encoder_hidden_states(
+            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+        )

         # 2. pre-process
         sample = self.conv_in(sample)
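
Taken together, forward() now routes through get_time_embed, get_class_embed, get_aug_embed, and process_encoder_hidden_states with keyword arguments, as the hunks above show. A hedged end-to-end sketch using a deliberately tiny, test-style configuration; the values are assumptions for illustration, not anything this commit prescribes.

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    layers_per_block=2,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 32, 32)              # latent-like input
encoder_hidden_states = torch.randn(1, 77, 32)  # text-encoder-like conditioning

with torch.no_grad():
    out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample

print(out.shape)  # expected: torch.Size([1, 4, 32, 32])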
