diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py
index 27a8f02..859bb65 100644
--- a/ring_attention_pytorch/ring_flash_attention_cuda.py
+++ b/ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -254,19 +254,17 @@ def backward(ctx, do):
         device = q.device

         dq = torch.zeros(q.shape, device = device, dtype = torch.float32)
-        dk = torch.zeros(k.shape, device = device, dtype = torch.float32)
-        dv = torch.zeros(v.shape, device = device, dtype = torch.float32)
+        dk = torch.zeros_like(k, device = device)
+        dv = torch.zeros_like(v, device = device)

-        # k and v will have 16 bits, while dk, dv needs to be kept at 32
-        # view everything as int for ring passing
-        # for minimizing communication
+        # k and v will have 16 bits, and dk and dv can also be accumulated safely with the same type, i think
+        # view everything as float32 for ring passing

-        k_dtype, v_dtype = k.dtype, v.dtype
+        assert k.dtype == v.dtype
+        kv_dtype = k.dtype

-        k, v = map(lambda t: t.view(torch.float32), (k, v))
-        kv = torch.cat((k, v), dim = -1)
-
-        kv_and_dkv = torch.stack((kv, dk, dv))
+        k, v, dk, dv = map(lambda t: t.view(torch.float32), (k, v, dk, dv))
+        kv_and_dkv = torch.stack((k, v, dk, dv))

         # receive buffers, to be alternated with sent buffer

@@ -279,12 +277,11 @@ def backward(ctx, do):

         for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(kv_and_dkv, mask, receive_buffers = (receive_kv_and_dkv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):

-            kv, dk, dv = kv_and_dkv
+            k, v, dk, dv = kv_and_dkv

-            # reconstitute correct types for k, v, dk, dv
+            # view k, v, dk, dv as the correct type of either float16 or bfloat16

-            k, v = kv.chunk(2, dim = -1)
-            k, v = k.view(k_dtype), v.view(v_dtype)
+            k, v, dk, dv = map(lambda t: t.view(kv_dtype), (k, v, dk, dv))

             # translate key padding mask to bias

@@ -313,7 +310,7 @@ def backward(ctx, do):
             # use flash attention backwards kernel to calculate dq, dk, dv and accumulate

             if need_accum:
-                ring_dq = torch.empty_like(q)
+                ring_dq = torch.empty(q.shape, device = device, dtype = torch.float32)
                 ring_dk = torch.empty_like(k)
                 ring_dv = torch.empty_like(v)

@@ -344,13 +341,15 @@ def backward(ctx, do):
            if not ring_reduce_col:
                continue

-        dkv = kv_and_dkv[1:]
+        dkv = kv_and_dkv[2:]

         max_ring_passes = default(max_ring_passes, ring_size)
         dkv = ring_pass(ring_size - max_ring_passes + 1, dkv)

         dk, dv = dkv

+        dk, dv = map(lambda t: t.view(kv_dtype), (dk, dv))
+
         dq, dk, dv = map(lambda t: t.to(dtype), (dq, dk, dv))

         return dq, dk, dv, None, None, None, None, None, None, None
diff --git a/setup.py b/setup.py
index b904a75..95ab65e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.17',
+  version = '0.3.18',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
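Note on the combined buffer: the change works because `Tensor.view(dtype)` reinterprets bits rather than casting, so the 16-bit k, v, dk, dv can all be viewed as float32, stacked into one tensor for ring passing, and viewed back losslessly on arrival. Below is a minimal sketch of that trick in isolation, with stand-in shapes and no distributed communication; the tensor names are illustrative and not taken from the library.

```python
import torch

# all four tensors share one 16-bit dtype, as in the patched backward
kv_dtype = torch.bfloat16

k  = torch.randn(2, 8, 64, dtype = kv_dtype)   # stand-in for keys
v  = torch.randn(2, 8, 64, dtype = kv_dtype)   # stand-in for values
dk = torch.zeros_like(k)                       # grads accumulated in the same dtype
dv = torch.zeros_like(v)

# bitwise reinterpretation: two 16-bit elements become one float32,
# so the last dimension halves (it must be even and contiguous)
k32, v32, dk32, dv32 = map(lambda t: t.view(torch.float32), (k, v, dk, dv))
assert k32.shape[-1] == k.shape[-1] // 2

# one homogeneous buffer that a single ring send/recv could ship around
kv_and_dkv = torch.stack((k32, v32, dk32, dv32))

# viewing back recovers the exact original bits
k_back = kv_and_dkv[0].view(kv_dtype)
assert torch.equal(k_back, k)
```

Packing everything into one float32-viewed stack means each ring hop moves a single buffer instead of separate k, v, dk, dv tensors, which is what lets the loop unpack `k, v, dk, dv = kv_and_dkv` and re-view them as `kv_dtype` per iteration.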