Use torch.repeat instead of expand on key & value in Triton MQA to prevent NaNs with certain h_dims (#442)
sashaDoubov authored Jul 8, 2023
1 parent 62e2fea commit 86a99e2
Showing 2 changed files with 47 additions and 8 deletions.
12 changes: 4 additions & 8 deletions llmfoundry/models/layers/attention.py
@@ -312,14 +312,10 @@ def triton_flash_attn_fn(
                       h=1 if multiquery else n_heads)

     if multiquery:
-        # Expanding a tensor does not allocate new memory, but only creates a new
-        # view on the existing tensor where a dimension of size one is expanded
-        # to a larger size by setting the stride to 0.
-        # - pytorch docs
-        #
-        # hopefully the kernels can utilize this and we're not just wasting BW here
-        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
-        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
+        # necessary to repeat instead of expand tensor because
+        # output contains NaN in edge cases such as with head dimension = 8
+        key = key.repeat(1, 1, n_heads, 1)
+        value = value.repeat(1, 1, n_heads, 1)

     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal,
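For context, a minimal standalone sketch (not part of the commit) of what the change swaps: .expand turns the single shared multiquery head into a zero-stride view over the same storage, while .repeat materializes n_heads real copies. The shapes below are arbitrary and only follow the (batch, seq, 1, head_dim) layout the multiquery path produces after the rearrange above.

import torch

batch, seq_len, n_heads, head_dim = 2, 4, 16, 8

# A single shared key head, shaped (batch, seq_len, 1, head_dim) as in the
# multiquery path after rearrange.
key = torch.randn(batch, seq_len, 1, head_dim)

# expand: a zero-stride view, no new memory is allocated
key_expanded = key.expand(batch, seq_len, n_heads, head_dim)
assert key_expanded.stride(2) == 0                # head dimension has stride 0
assert key_expanded.data_ptr() == key.data_ptr()  # same underlying storage

# repeat: materializes n_heads copies along the head dimension
key_repeated = key.repeat(1, 1, n_heads, 1)
assert key_repeated.stride(2) == head_dim         # contiguous, real memory
assert key_repeated.data_ptr() != key.data_ptr()  # new storage

# The values are identical either way; only the memory layout differs, which
# is what the Triton kernel appears to be sensitive to in the NaN edge case.
assert torch.equal(key_expanded, key_repeated)

The repeat call does spend the extra memory bandwidth the deleted comment was trying to avoid; the commit accepts that cost so the kernel output stays finite.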
43 changes: 43 additions & 0 deletions tests/test_model.py
@@ -1385,3 +1385,46 @@ def test_hf_init(tmp_path,
     updated_params = next(model.parameters()).clone().data

     assert not torch.equal(original_params, updated_params)
+
+
+@pytest.mark.gpu
+def test_head_dim_8_triton_mqa_attn(batch_size=2):
+    test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
+    test_cfg.device = torch.cuda.current_device()
+
+    test_cfg.batch_size = batch_size
+
+    hf_config = MPTConfig(
+        init_device='cpu',
+        d_model=128,
+        n_heads=16,
+        n_layers=1,
+        expansion_ratio=2,
+        max_seq_len=128,
+        emb_pdrop=0.1,
+        resid_pdrop=0.2,
+        attn_config={
+            'attn_impl': 'triton',
+            'attn_type': 'multiquery_attention'
+        },
+    )
+    test_cfg.device = torch.cuda.current_device()
+
+    tokenizer = build_tokenizer(test_cfg.tokenizer)
+
+    mpt = MPTForCausalLM(hf_config)
+
+    model = HuggingFaceModelWithZLoss(mpt, tokenizer, shift_labels=True)
+
+    model = model.to(test_cfg.device)
+    batch = gen_random_batch(batch_size, test_cfg)
+
+    assert batch['input_ids'].shape == torch.Size(
+        [batch_size, test_cfg.max_seq_len])
+
+    model.train()
+
+    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
+        output = model(batch)
+
+    assert not torch.isnan(output.logits).any()
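As a side note (illustrative, not part of the commit), the MPTConfig above is chosen so that the per-head dimension lands exactly on the reported edge case:

# Head dimension implied by the test's MPTConfig: d_model split across n_heads.
d_model, n_heads = 128, 16
head_dim = d_model // n_heads
assert head_dim == 8  # the h_dim that previously produced NaNs with expand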
