From f3ed323ddcb0be34b6698e1d2ef7b25c21069bc2 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 15 Aug 2024 09:59:19 -0700 Subject: [PATCH] readme --- README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.md b/README.md index cd974c2..83fd41d 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,16 @@ attended = attn(tokens) assert attended.shape == tokens.shape ``` +This repository also contains an implementation of Tree Attention Decoding from Shyam et al. + +It can be imported and used as follows + +```python +from ring_attention_pytorch import tree_attn_decode + +out = tree_attn_decode(q, k, v) # where q, k, v exists across all machines +``` + ## Test First install requirements @@ -62,6 +72,12 @@ Then say testing autoregressive striped ring attention on cuda would be $ python assert.py --use-cuda --causal --striped-ring-attn ``` +Testing tree attention would be + +```bash +$ python assert_tree_attn.py --use-cuda --seq-len 8192 +``` + ## Todo - [x] make it work with derived causal mask based on rank and chunk sizes