Commit

another assert

lucidrains committed Aug 14, 2024
1 parent 98d72b0 commit ed59cee
Showing 2 changed files with 3 additions and 2 deletions.
assert_tree_attn.py (1 change: 0 additions & 1 deletion)
@@ -87,7 +87,6 @@ def start(
     tree_out = tree_out.cpu()
     out = out.cpu()
 
-    print((tree_out - out).abs().amax())
     output_atol = 1e-2 if use_cuda else 1e-5
 
     assert torch.allclose(tree_out, out, atol = output_atol), '🟥 output is not the same'
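The remaining assertion checks that the tree attention decoding output matches the reference to within a tolerance that is looser on CUDA, presumably because the Triton path reduces in a different order and accumulates more floating point error. A minimal standalone sketch of this kind of check, with hypothetical tensors standing in for the two outputs, might look like:

    import torch

    # hypothetical stand-ins for the tree-decoded output and the single-device reference
    out = torch.randn(2, 8, 64)
    tree_out = out + 1e-6 * torch.randn_like(out)

    use_cuda = torch.cuda.is_available()

    # looser tolerance when results come from CUDA kernels
    output_atol = 1e-2 if use_cuda else 1e-5

    assert torch.allclose(tree_out, out, atol = output_atol), 'output is not the same'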
ring_attention_pytorch/tree_attn_decoding.py (4 changes: 3 additions & 1 deletion)
@@ -42,13 +42,15 @@ def tree_attn_decode(
     # each machine (rank) takes care of a chunk of kv sequence within the world of many machines
 
     if shard_kv_seq:
+        assert exists(k), 'keys and values must be passed if not already sharded across sequence'
+
         rank, world_size = get_rank(), get_world_size()
         k = k.chunk(world_size, dim = -2)
         v = v.chunk(world_size, dim = -2)
 
         k, v = (k[rank], v[rank]) if rank < len(k) else (None, None)
 
-    if exists(k) and exists(v):
+    if exists(k):
         # calculate local output and derive numerator and denominator
 
         use_triton = default(use_triton, q.is_cuda)
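Worth noting for readers of this hunk: torch.chunk can return fewer chunks than requested when the kv sequence length is smaller than the world size, which is why the rank is compared against len(k) before indexing. A rank without a shard ends up with k and v both None, which is presumably why the later check could be simplified to exists(k) alone. A standalone sketch of that sharding step, using made-up shapes and a hypothetical helper name, is below:

    import torch

    def shard_kv_for_rank(k, v, rank, world_size):
        # hypothetical helper mirroring the sharding step above:
        # split keys / values along the sequence dimension (dim = -2)
        k = k.chunk(world_size, dim = -2)
        v = v.chunk(world_size, dim = -2)

        # ranks beyond the number of returned chunks get no kv to attend over
        return (k[rank], v[rank]) if rank < len(k) else (None, None)

    # toy check: 4 ranks but only 3 kv positions, so torch.chunk yields 3 chunks and rank 3 gets None
    k = torch.randn(1, 3, 64)
    v = torch.randn(1, 3, 64)

    for rank in range(4):
        k_local, v_local = shard_kv_for_rank(k, v, rank, world_size = 4)
        print(rank, None if k_local is None else tuple(k_local.shape))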
