Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyixuxu committed Jan 13, 2025
1 parent 42d3a6a commit becbcd6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/diffusers/models/transformers/sana_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ class SanaModulatedNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)

def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor) -> torch.Tensor:

def forward(
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
shift, scale = (
scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)
).chunk(2, dim=1)
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
return hidden_states

Expand Down Expand Up @@ -235,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]

@register_to_config
def __init__(
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import requests_mock
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size, compute_module_sizes
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
Expand Down
1 change: 0 additions & 1 deletion tests/models/transformers/test_models_transformer_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import unittest

import pytest
import torch

from diffusers import SanaTransformer2DModel
Expand Down

0 comments on commit becbcd6

Please sign in to comment.