diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 76b3cb0d58..9adaed7ec1 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -312,14 +312,10 @@ def triton_flash_attn_fn(
                         h=1 if multiquery else n_heads)
 
     if multiquery:
-        # Expanding a tensor does not allocate new memory, but only creates a new
-        # view on the existing tensor where a dimension of size one is expanded
-        # to a larger size by setting the stride to 0.
-        # - pytorch docs
-        #
-        # hopefully the kernels can utilize this and we're jot just wasting BW here
-        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
-        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
+        # It is necessary to repeat, rather than expand, the tensor because the
+        # output contains NaN in edge cases such as head dimension = 8.
+        key = key.repeat(1, 1, n_heads, 1)
+        value = value.repeat(1, 1, n_heads, 1)
 
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal,
diff --git a/tests/test_model.py b/tests/test_model.py
index 0a13ecd145..5266b82622 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1385,3 +1385,46 @@ def test_hf_init(tmp_path,
     updated_params = next(model.parameters()).clone().data
 
     assert not torch.equal(original_params, updated_params)
+
+
+@pytest.mark.gpu
+def test_head_dim_8_triton_mqa_attn(batch_size=2):
+    test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
+    test_cfg.device = torch.cuda.current_device()
+
+    test_cfg.batch_size = batch_size
+
+    hf_config = MPTConfig(
+        init_device='cpu',
+        d_model=128,
+        n_heads=16,
+        n_layers=1,
+        expansion_ratio=2,
+        max_seq_len=128,
+        emb_pdrop=0.1,
+        resid_pdrop=0.2,
+        attn_config={
+            'attn_impl': 'triton',
+            'attn_type': 'multiquery_attention'
+        },
+    )
+    test_cfg.device = torch.cuda.current_device()
+
+    tokenizer = build_tokenizer(test_cfg.tokenizer)
+
+    mpt = MPTForCausalLM(hf_config)
+
+    model = HuggingFaceModelWithZLoss(mpt, tokenizer, shift_labels=True)
+
+    model = model.to(test_cfg.device)
+    batch = gen_random_batch(batch_size, test_cfg)
+
+    assert batch['input_ids'].shape == torch.Size(
+        [batch_size, test_cfg.max_seq_len])
+
+    model.train()
+
+    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
+        output = model(batch)
+
+    assert not torch.isnan(output.logits).any()
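
A minimal standalone sketch (not part of the patch; shapes are chosen for illustration) of the PyTorch semantics the fix hinges on: expand returns a stride-0 view over the single key/value head, whereas repeat materializes a contiguous copy. That the Triton flash-attention kernel mishandles the stride-0 layout at small head dimensions is inferred from the patch comment, not demonstrated below.

import torch

n_heads, head_dim = 16, 8
# Multi-query attention keeps a single KV head: (batch, seq, 1, head_dim).
key = torch.randn(2, 128, 1, head_dim)

expanded = key.expand(*key.shape[:2], n_heads, key.size(-1))  # old code path
repeated = key.repeat(1, 1, n_heads, 1)                       # new code path

# Both yield the same logical tensor ...
assert expanded.shape == repeated.shape == (2, 128, n_heads, head_dim)
assert torch.equal(expanded, repeated)

# ... but expand shares storage via a zero stride on the head dimension,
# while repeat allocates fresh, contiguous memory.
assert expanded.stride(2) == 0 and expanded.data_ptr() == key.data_ptr()
assert repeated.is_contiguous() and repeated.data_ptr() != key.data_ptr()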
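
On a CUDA machine with the triton-based flash attention dependencies installed, the regression test should be selectable on its own; the exact invocation below is an assumption, not part of the patch:

pytest -m gpu tests/test_model.py -k test_head_dim_8_triton_mqa_attn

The config deliberately sets d_model=128 and n_heads=16 so that the head dimension is 128 / 16 = 8, the edge case named in the new comment, and the test asserts only that no NaNs appear in the logits under bf16 autocast.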