diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index fe3823e32acf..7db4d3d17d2f 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -20,7 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..utils import is_torch_version
+from ..utils import is_torch_npu_available, is_torch_version
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
@@ -505,19 +505,30 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool
                 self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
-        if self.weight is not None:
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = hidden_states * self.weight
+        if is_torch_npu_available():
+            import torch_npu
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
         else:
-            hidden_states = hidden_states.to(input_dtype)
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                hidden_states = hidden_states * self.weight
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+            else:
+                hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
 
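
For reviewers who want a quick local check, here is a minimal, hypothetical smoke-test sketch (not part of this patch): it recomputes the fallback-branch math by hand and compares it against `RMSNorm.forward`, which on Ascend hardware now routes through `torch_npu.npu_rms_norm`. It assumes `diffusers` is installed from this branch; the `"npu"` device move and `torch_npu` are optional, and without them only the pure-PyTorch branch is exercised.

```python
# Sketch only: hypothetical sanity check for this change, not part of the PR.
# Assumes diffusers is installed from this branch; torch_npu / an Ascend device
# are optional -- without them only the pure-PyTorch fallback branch runs.
import torch

from diffusers.models.normalization import RMSNorm
from diffusers.utils import is_torch_npu_available

dim = 64
norm = RMSNorm(dim, eps=1e-6, elementwise_affine=True)
x = torch.randn(2, 8, dim, dtype=torch.float16)

# Reference value computed the same way as the fallback branch in the diff:
# RMS statistics in float32, then scaling by the learnable weight.
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
reference = x * torch.rsqrt(variance + norm.eps) * norm.weight

if is_torch_npu_available():
    # On Ascend hardware forward() now dispatches to the fused
    # torch_npu.npu_rms_norm kernel instead of the manual computation.
    norm, x, reference = norm.to("npu"), x.to("npu"), reference.to("npu")

out = norm(x)
print(torch.allclose(out.float(), reference.float(), atol=1e-3, rtol=1e-3))
```

The tolerances are loose on purpose: the fused NPU kernel may round differently from the float32 reference path, so exact bitwise equality is not expected.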