refactor: extract init/forward function in UNet2DConditionModel #6478
Conversation
- Add new function get_mid_block() to unet_2d_blocks.py
Thanks for your contributions! We have already started refactoring the UNet, and it will get cleaner and cleaner in the coming days.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cool, is there any relevant issue or PR there?
There are several actually. We have started taking a bottom-up approach here. So, we're also refactoring many other building blocks such as the embeddings class, the ResNet2D class, etc. Some relevant PRs:
Does this help? Cc: @patrickvonplaten @DN6 and @yiyixuxu here.
That's actually a very nice refactor in my opinion - @DN6 @yiyixuxu @sayakpaul can you take a look here?
thanks a lot for working on this! I love this PR!
@@ -240,6 +240,59 @@ def get_down_block(
    raise ValueError(f"{down_block_type} does not exist.")


def get_mid_block(mid_block_type, block_out_channels, mid_block_scale_factor, dropout, act_fn, norm_num_groups, norm_eps, cross_attention_dim, transformer_layers_per_block, attention_head_dim, num_attention_heads, dual_cross_attention, use_linear_projection, upcast_attention, resnet_time_scale_shift, resnet_skip_time_act, attention_type, mid_block_only_cross_attention, cross_attention_norm, blocks_time_embed_dim):
Can we add type hints like we do for get_up_block and get_down_block?
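For illustration, here is a minimal sketch of what a type-hinted get_mid_block signature could look like, mirroring get_down_block and get_up_block. The parameter order follows the diff above, but the concrete type annotations are assumptions rather than the merged code.

from typing import Optional, Tuple, Union

from torch import nn


def get_mid_block(
    mid_block_type: str,
    block_out_channels: Tuple[int, ...],
    mid_block_scale_factor: float,
    dropout: float,
    act_fn: str,
    norm_num_groups: int,
    norm_eps: float,
    cross_attention_dim: Optional[Union[int, Tuple[int, ...]]],
    transformer_layers_per_block: Union[int, Tuple[int, ...]],
    attention_head_dim: Optional[int],
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]],
    dual_cross_attention: bool,
    use_linear_projection: bool,
    upcast_attention: bool,
    resnet_time_scale_shift: str,
    resnet_skip_time_act: bool,
    attention_type: str,
    mid_block_only_cross_attention: bool,
    cross_attention_norm: Optional[str],
    blocks_time_embed_dim: int,
) -> Optional[nn.Module]:
    # Body omitted; the point of this sketch is only the annotated signature.
    ...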
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
time_embed_dim, timestep_input_dim = self._set_time_embed_layer(flip_sin_to_cos, freq_shift, block_out_channels, act_fn, time_embedding_type, time_embedding_dim, timestep_post_act, time_cond_proj_dim)
Let's separate this function into _set_time_proj and _set_time_embedding, and make sure we have consistent naming logic across all these methods, i.e.
- _set_time_proj() sets the self.time_proj layer
- _set_time_embedding() sets the self.time_embedding layer
- _set_encoder_hid_proj() sets the self.encoder_hid_proj layer
- _set_class_embedding() sets the self.class_embedding layer
- _set_add_embedding() sets the self.add_embedding layer
IMO this is especially important for functions that update things in place, to help our users get a good sense of what's been updated without having to go into the code.
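To make the convention concrete, here is a minimal, hypothetical sketch of one such setter. The condition and the layer built below are placeholders, not the actual UNet2DConditionModel logic; only the naming pattern (method name mirrors the attribute it assigns in place) is the point.

from typing import Optional

from torch import nn


class ExampleModel(nn.Module):
    # Hypothetical model used only to illustrate the _set_* naming pattern.
    def __init__(self, num_class_embeds: Optional[int], time_embed_dim: int):
        super().__init__()
        self._set_class_embedding(num_class_embeds, time_embed_dim)

    def _set_class_embedding(self, num_class_embeds: Optional[int], time_embed_dim: int) -> None:
        # Sets self.class_embedding in place and returns nothing; the method
        # name tells the reader exactly which attribute was updated.
        if num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        else:
            self.class_embedding = None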
This also gives a nice chronology of the layers to the readers.
Could it also make sense to add a _ at the end of the function names to denote that the functions do things in-place?
I think _set is a nice prefix to denote in-place operations. _ at the end is a bit unconventional.
- Currently there is only one implementation for the time_embedding layer, so _set_time_embedding is not necessary?
- The _set prefix also acts as an indicator for internal use, i.e. the method should not be called directly by users.
- _ at the end actually follows the PyTorch naming style for in-place actions, but it might not be necessary if we already have a prefix?
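For reference, the two conventions being compared, shown on generic examples rather than diffusers code:

import torch
from torch import nn

# PyTorch style: a trailing underscore marks an in-place tensor operation.
x = torch.ones(3)
x.add_(1.0)  # modifies x in place; torch.add(x, 1.0) would return a new tensor

# The style proposed here: a leading-underscore _set_* method marks an
# internal helper that mutates the module in place.
class Example(nn.Module):
    def _set_scale(self, scale: float) -> None:
        self.scale = scale  # assigns an attribute in place; internal use only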
    self.mid_block = None
else:
    raise ValueError(f"unknown mid_block_type : {mid_block_type}")
self.mid_block = get_mid_block(mid_block_type, block_out_channels, mid_block_scale_factor, dropout, act_fn, norm_num_groups, norm_eps, cross_attention_dim, transformer_layers_per_block, attention_head_dim, num_attention_heads, dual_cross_attention, use_linear_projection, upcast_attention, resnet_time_scale_shift, resnet_skip_time_act, attention_type, mid_block_only_cross_attention, cross_attention_norm, blocks_time_embed_dim)
❤️
for layer_number_per_block in transformer_layers_per_block:
    if isinstance(layer_number_per_block, list):
        raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
self._check_config(down_block_types, up_block_types, only_cross_attention, block_out_channels, layers_per_block, cross_attention_dim, transformer_layers_per_block, reverse_transformer_layers_per_block, attention_head_dim, num_attention_heads)
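For orientation, a hedged sketch of the kind of validation that _check_config centralizes. Only the asymmetry check visible in the diff above is reproduced; the trimmed argument list and the surrounding guard are assumptions, not the merged code.

def _check_config(self, transformer_layers_per_block, reverse_transformer_layers_per_block, **other_config_kwargs):
    # Sketch of a method on UNet2DConditionModel; the real helper receives the
    # full argument list shown in the diff above and performs more checks.
    if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
        for layer_number_per_block in transformer_layers_per_block:
            if isinstance(layer_number_per_block, list):
                raise ValueError(
                    "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
                )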
Nice 🚀
Uff, what a lovely PR! My eyes feel better now!
Looking very nice!
…_block; rename _set_xxx function
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
if self.encoder_hid_proj is not None:
    encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
Adding a short if statement here makes it a bit easier to understand IMO
process_encoder_hidden_states (or rename to process_added_cond?) without the if statement outside could be more easily extendable for future changes IMO?
Actually I'd rename the function to project_encoder_hidden_states because we linearly project the encoder hidden states here. Then personally I think it's much better to have it in the if statement, because most of the SD models do not use this function (it's only the IF models really). If we have it in an if-statement, people reading the code who know SD will see directly that this function is not applied.
So I guess the following would be ideal for me:
if self.encoder_hid_proj is not None:
    encoder_hidden_states = self.project_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
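As a rough sketch of what the guarded helper could encapsulate, based only on the image_embeds branch visible in the diff above: the encoder_hid_dim_type check and the omission of the other projection branches are assumptions for illustration, not the merged code.

import torch


def project_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs):
    # Sketch of a method on UNet2DConditionModel: project extra conditioning
    # through self.encoder_hid_proj and merge it into encoder_hidden_states.
    if self.config.encoder_hid_dim_type == "ip_image_proj":
        image_embeds = added_cond_kwargs.get("image_embeds")
        image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
        encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
    return encoder_hidden_states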
This is a super cool PR @ultranity! Very important to keep diffusers readable and to fight bloated code.
I just have some minor suggestions - overall very happy to merge this one soon!
About the last failing code quality check: as versatile_diffusion is deprecated, I'd prefer to remove the Copy-from tag directly. What do you think @patrickvonplaten?
Agree 100%! Could you maybe remove the Copy-from from versatile diffusion?
Great job @ultranity!
…ingface#6478)
* - extract function for stage in UNet2DConditionModel init & forward
  - Add new function get_mid_block() to unet_2d_blocks.py
* add type hint to get_mid_block aligned with get_up_block and get_down_block; rename _set_xxx function
* add type hint and use keyword arguments
* remove `copy from` in versatile diffusion
What does this PR do?
The current UNet2DConditionModel mixes many variants into one very long implementation even though it is widely used. Extracting some stage functions might help developers and researchers better understand the code and make it easier to hack.
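To make the "stage function" idea concrete, here is a rough sketch of how the constructor reads after extraction. The helper names follow the review thread above, but the placeholder signatures and bodies are assumptions rather than the merged diffusers code.

from torch import nn


class UNet2DConditionModelSketch(nn.Module):
    # Illustrative stand-in for UNet2DConditionModel, showing only the
    # high-level ordering of the extracted stage helpers.
    def __init__(self, config: dict):
        super().__init__()
        self._check_config(config)          # input validation, pulled out of the long __init__ body
        self._set_time_proj(config)         # sets self.time_proj
        self._set_encoder_hid_proj(config)  # sets self.encoder_hid_proj
        self._set_class_embedding(config)   # sets self.class_embedding
        self._set_add_embedding(config)     # sets self.add_embedding
        # ... followed by the down blocks, self.mid_block = get_mid_block(...), and the up blocks

    # Placeholder bodies so the sketch runs; the real helpers build the layers.
    def _check_config(self, config):
        pass

    def _set_time_proj(self, config):
        self.time_proj = nn.Identity()

    def _set_encoder_hid_proj(self, config):
        self.encoder_hid_proj = None

    def _set_class_embedding(self, config):
        self.class_embedding = None

    def _set_add_embedding(self, config):
        self.add_embedding = None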
Before submitting
See the documentation guidelines and the tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten and @sayakpaul