readme
lucidrains authored Aug 15, 2024
1 parent 103379f commit f3ed323
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions README.md
@@ -48,6 +48,16 @@ attended = attn(tokens)
assert attended.shape == tokens.shape
```

This repository also contains an implementation of <a href="https://arxiv.org/abs/2408.04093">Tree Attention Decoding</a> from Shyam et al.

It can be imported and used as follows:

```python
from ring_attention_pytorch import tree_attn_decode

out = tree_attn_decode(q, k, v) # where q, k, v exist across all machines
```
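
As a rough, single-process illustration of the idea behind tree attention decoding (this is not the code in this repository): each machine attends over its own shard of keys and values, and the partial results are merged with an associative, log-sum-exp style combine, which is what makes a tree-shaped reduction possible. The helper names `local_attend`, `combine`, and `sharded_decode` are hypothetical, and plain Python lists stand in for tensors and devices.

```python
import functools
import math

def local_attend(q, keys, values):
    """Attend one query over a local shard.

    Returns (max score, unnormalized weighted sum of values, sum of weights).
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    den = sum(weights)
    num = [sum(w * v[d] for w, v in zip(weights, values))
           for d in range(len(values[0]))]
    return m, num, den

def combine(a, b):
    """Merge two partial results after rescaling both to a shared max.

    The operation is associative, so partials can be reduced in any order.
    """
    (ma, na, da), (mb, nb, db) = a, b
    m = max(ma, mb)
    ca, cb = math.exp(ma - m), math.exp(mb - m)
    num = [ca * x + cb * y for x, y in zip(na, nb)]
    return m, num, ca * da + cb * db

def sharded_decode(q, keys, values, num_shards):
    """Simulate sharded decoding: split keys/values, attend locally, then reduce."""
    size = math.ceil(len(keys) / num_shards)
    partials = [
        local_attend(q, keys[i:i + size], values[i:i + size])
        for i in range(0, len(keys), size)
    ]
    _, num, den = functools.reduce(combine, partials)
    return [x / den for x in num]
```

Calling `sharded_decode` with one shard and with several shards should give the same output up to floating point error, since the combine step only rescales and adds partial sums.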

## Test

First install requirements
@@ -62,6 +72,12 @@ Then say testing autoregressive striped ring attention on cuda would be
$ python assert.py --use-cuda --causal --striped-ring-attn
```

Testing tree attention would be

```bash
$ python assert_tree_attn.py --use-cuda --seq-len 8192
```

## Todo

- [x] make it work with derived causal mask based on rank and chunk sizes
