Commit ac2fc57

fix

qingquansong committed Aug 30, 2024
1 parent c1bb445 commit ac2fc57
Showing 2 changed files with 16 additions and 17 deletions.
14 changes: 5 additions & 9 deletions src/liger_kernel/ops/layer_norm.py
@@ -136,7 +136,6 @@ def layer_norm_forward(X, W, B, eps):
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
@@ -163,10 +162,10 @@ def layer_norm_forward(X, W, B, eps):
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
     )
-    return Y.view(*shape), Mean, RSTD, BLOCK_SIZE, num_warps
+    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps


-def layer_norm_backward(dY, X, W, B, Mean, RSTD, BLOCK_SIZE, num_warps):
+def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
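
After the hunk above, layer_norm_forward also hands back the flattened X it actually normalized, and layer_norm_backward no longer takes the Triton launch settings. A minimal sketch of the new calling convention for the two helpers, with hypothetical shapes and assuming a CUDA device for the Triton kernels:

import torch
from liger_kernel.ops.layer_norm import layer_norm_forward, layer_norm_backward

# Hypothetical inputs: batch 2, sequence 8, hidden size 64.
X = torch.randn(2, 8, 64, device="cuda")
W = torch.ones(64, device="cuda")
B = torch.zeros(64, device="cuda")

# New return tuple: the flattened X comes back alongside Y, Mean and RSTD.
Y, X2d, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, 1e-6)

# New backward signature: the launch settings are no longer passed in.
dY = torch.ones_like(Y)
DX, DW, DB = layer_norm_backward(dY, X2d, W, B, Mean, RSTD)
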
@@ -177,6 +176,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD, BLOCK_SIZE, num_warps):
     _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
     _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)

+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

@@ -216,17 +216,13 @@ class LigerLayerNormFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, X, W, B, eps):
-        Y, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
-        ctx.BLOCK_SIZE = BLOCK_SIZE
-        ctx.num_warps = num_warps
+        Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
         ctx.save_for_backward(X, W, B, Mean, RSTD)
         return Y

     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
         X, W, B, Mean, RSTD = ctx.saved_tensors
-        BLOCK_SIZE = ctx.BLOCK_SIZE
-        num_warps = ctx.num_warps
-        DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD, BLOCK_SIZE, num_warps)
+        DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
         return DX, DW, DB, None
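
At the autograd level, the ctx.BLOCK_SIZE and ctx.num_warps bookkeeping can go away because layer_norm_backward now recomputes the settings from n_cols via calculate_settings, a deterministic function of the feature dimension, so it recovers exactly what the forward pass used. A quick smoke-test sketch of the updated LigerLayerNormFunction, with hypothetical shapes and assuming a CUDA device:

import torch
from liger_kernel.ops.layer_norm import LigerLayerNormFunction

X = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
W = torch.ones(64, device="cuda", requires_grad=True)
B = torch.zeros(64, device="cuda", requires_grad=True)

Y = LigerLayerNormFunction.apply(X, W, B, 1e-6)  # forward(ctx, X, W, B, eps)
Y.sum().backward()  # backward recomputes BLOCK_SIZE / num_warps internally
print(X.grad.shape, W.grad.shape, B.grad.shape)
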
19 changes: 11 additions & 8 deletions src/liger_kernel/ops/rms_norm.py
@@ -229,7 +229,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
     )
-    return Y.view(*shape), r, BLOCK_SIZE, num_warps, casting_mode
+    return Y.view(*shape), X, r, BLOCK_SIZE, num_warps, casting_mode


 def rms_norm_backward(dY, X, W, r, eps, offset, casting_mode, BLOCK_SIZE, num_warps):
@@ -291,7 +291,7 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, r, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
+        Y, X, r, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
             X, W, eps, offset, casting_mode
         )
         ctx.eps = eps
@@ -309,12 +309,15 @@ def backward(ctx, dY):
         Y: (B, T, H) or (BxT, H)
         """
         X, W, r = ctx.saved_tensors
-        eps = ctx.eps
-        offset = ctx.offset
-        casting_mode = ctx.casting_mode
-        BLOCK_SIZE = ctx.BLOCK_SIZE
-        num_warps = ctx.num_warps
         dX, dW = rms_norm_backward(
-            dY, X, W, r, eps, offset, casting_mode, BLOCK_SIZE, num_warps
+            dY,
+            X,
+            W,
+            r,
+            ctx.eps,
+            ctx.offset,
+            ctx.casting_mode,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
         )
         return dX, dW, None, None, None
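
The rms_norm.py changes mirror the layer-norm ones: rms_norm_forward now returns X as well, so the tensor saved for backward is the one the kernel actually consumed, and backward passes eps, offset, casting_mode and the launch settings straight from ctx into rms_norm_backward. A small end-to-end sketch against a plain PyTorch reference, assuming the autograd wrapper in this module is named LigerRMSNormFunction and that a CUDA device is available; the reference formula is the standard RMSNorm definition, not something taken from this diff:

import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # assumed class name

def reference_rms_norm(x, w, eps, offset=0.0):
    # Standard RMSNorm: scale x by its root-mean-square, then by (offset + w).
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x / rms) * (offset + w)

X = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
W = torch.ones(64, device="cuda", requires_grad=True)

Y = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama")  # forward(ctx, X, W, eps, offset, casting_mode)
Y.sum().backward()  # exercises the reworked backward path
print(torch.allclose(Y, reference_rms_norm(X, W, 1e-6), atol=1e-4, rtol=1e-4))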
