Skip to content

Commit

Permalink
NPU adaption for RMSNorm (#10534)
Browse files Browse the repository at this point in the history
* NPU adaption for RMSNorm

* NPU adaption for RMSNorm

---------

Co-authored-by: J石页 <[email protected]>
  • Loading branch information
leisuzz and J石页 authored Jan 16, 2025
1 parent 17d99c4 commit cecada5
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/diffusers/models/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ..utils import is_torch_version
from ..utils import is_torch_npu_available, is_torch_version
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings

Expand Down Expand Up @@ -505,19 +505,30 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool
self.bias = nn.Parameter(torch.zeros(dim))

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
if is_torch_npu_available():
import torch_npu

if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
if self.bias is not None:
hidden_states = hidden_states + self.bias
else:
hidden_states = hidden_states.to(input_dtype)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
if self.bias is not None:
hidden_states = hidden_states + self.bias
else:
hidden_states = hidden_states.to(input_dtype)

return hidden_states

Expand Down

0 comments on commit cecada5

Please sign in to comment.