From 7f2c92219ba72bab409fc8421c378f6f732853fb Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 12 Jul 2024 20:48:30 -0700 Subject: [PATCH 1/2] Updated sdpa with enable_gqa=True --- README.md | 14 ++++++++++++++ model.py | 4 +--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8ae96eb..ba65cb7 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,20 @@ Benchmarks run on one GCD of a MI-250x. | Llama-2-7B | Base | 76.33 | 1028.70 | | | 8-bit | 101.86 | 700.06 | +### Using Grouped Query Attention +Benchmarks run on 1 NVIDIA H100. + +Using ```bash +export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf +``` + +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | +| -------- | ------- | ------ | ------ | +| Llama-2-7B | Base | 146.66 | 1938.12 | +| | 8-bit | 233.50 | 1543.55 | +| | 4-bit (G=32) | 267.11 | 1103.14 | + + ## Generate Text Model definition in `model.py`, generation code in `generate.py`. diff --git a/model.py b/model.py index b89a19a..06cf365 100644 --- a/model.py +++ b/model.py @@ -195,9 +195,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona if self.kv_cache is not None: k, v = self.kv_cache.update(input_pos, k, v) - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, enable_gqa=True) y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) From 9dce6a4d267ca036cbaf5f862d21467c7b2f488a Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 12 Jul 2024 20:57:56 -0700 Subject: [PATCH 2/2] Lint fixes --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ba65cb7..ff9297d 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,8 @@ Benchmarks run on one GCD of a MI-250x. ### Using Grouped Query Attention Benchmarks run on 1 NVIDIA H100. -Using ```bash +Using +```bash export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf ``` @@ -142,7 +143,6 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf | | 8-bit | 233.50 | 1543.55 | | | 4-bit (G=32) | 267.11 | 1103.14 | - ## Generate Text Model definition in `model.py`, generation code in `generate.py`.