From d2fbc3b47a39955872f4c03afd20122e449e28ec Mon Sep 17 00:00:00 2001 From: Sasha Doubov Date: Fri, 11 Aug 2023 10:35:21 -0700 Subject: [PATCH] Grouped Query Attention + Refactor Attn (#492) Adds support for GQA, and refactors MHA and MQA as special cases of GQA. --------- Co-authored-by: root Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/layers/attention.py | 288 ++++++++++++--------- llmfoundry/models/layers/blocks.py | 30 ++- llmfoundry/models/mpt/configuration_mpt.py | 4 +- tests/test_flash_triton_torch.py | 120 ++++++++- tests/test_model.py | 6 +- 5 files changed, 293 insertions(+), 155 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 34692e600b..c99e81875b 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -36,6 +36,7 @@ def scaled_multihead_dot_product_attention( key: torch.Tensor, value: torch.Tensor, n_heads: int, + kv_n_heads: Optional[int] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, @@ -47,8 +48,21 @@ def scaled_multihead_dot_product_attention( multiquery: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]: + + if multiquery: + warnings.warn( + DeprecationWarning( + 'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.' + )) + kv_n_heads = 1 + elif kv_n_heads is None: + warnings.warn( + DeprecationWarning( + 'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.' + )) + kv_n_heads = n_heads + q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) - kv_n_heads = 1 if multiquery else n_heads k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) @@ -68,6 +82,11 @@ def scaled_multihead_dot_product_attention( b, _, s_q, d = q.shape s_k = k.size(-1) + # grouped query case + if kv_n_heads > 1 and kv_n_heads < n_heads: + k = k.repeat_interleave(n_heads // kv_n_heads, dim=1) + v = v.repeat_interleave(n_heads // kv_n_heads, dim=1) + if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) @@ -143,6 +162,7 @@ def flash_attn_fn( key: torch.Tensor, value: torch.Tensor, n_heads: int, + kv_n_heads: Optional[int] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, @@ -161,6 +181,19 @@ def flash_attn_fn( check_valid_inputs(query, key, value) + if multiquery: + warnings.warn( + DeprecationWarning( + 'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.' + )) + kv_n_heads = 1 + elif kv_n_heads is None: + warnings.warn( + DeprecationWarning( + 'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.' 
+ )) + kv_n_heads = n_heads + if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) @@ -189,16 +222,13 @@ def flash_attn_fn( key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input( key, key_padding_mask) - key_unpad = rearrange(key_unpad, - 'nnz (h d) -> nnz h d', - h=1 if multiquery else n_heads) + key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange(value_unpad, - 'nnz (h d) -> nnz h d', - h=1 if multiquery else n_heads) + value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - if multiquery: + # multi-query case + if kv_n_heads == 1: # Expanding a tensor does not allocate new memory, but only creates a new # view on the existing tensor where a dimension of size one is expanded # to a larger size by setting the stride to 0. @@ -209,6 +239,14 @@ def flash_attn_fn( key_unpad.size(-1)) value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) + # grouped query case + elif kv_n_heads < n_heads: + # Each query belong to a group of kv heads of group size n_heads // kv_n_heads + # We repeat each kv head by the group size number to use use the underlying MHA kernels + # done along the head dimension = 1 + key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1) + value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads, + dim=1) dropout_p = dropout_p if training else 0.0 @@ -238,6 +276,7 @@ def triton_flash_attn_fn( key: torch.Tensor, value: torch.Tensor, n_heads: int, + kv_n_heads: Optional[int] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, @@ -278,6 +317,19 @@ def triton_flash_attn_fn( check_valid_inputs(query, key, value) + if multiquery: + warnings.warn( + DeprecationWarning( + 'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.' + )) + kv_n_heads = 1 + elif kv_n_heads is None: + warnings.warn( + DeprecationWarning( + 'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.' 
+ )) + kv_n_heads = n_heads + if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) @@ -318,16 +370,22 @@ def triton_flash_attn_fn( torch.finfo(query.dtype).min) query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) - key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) - value = rearrange(value, - 'b s (h d) -> b s h d', - h=1 if multiquery else n_heads) + key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads) + value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads) - if multiquery: + # multi-query case + if kv_n_heads == 1: # necessary to repeat instead of expand tensor because # output contains NaN in edge cases such as with head dimension = 8 key = key.repeat(1, 1, n_heads, 1) value = value.repeat(1, 1, n_heads, 1) + # grouped query case + elif kv_n_heads < n_heads: + # Each query belong to a group of kv heads of group size n_heads // kv_n_heads + # We repeat each kv head by the group size number to use use the underlying MHA kernels + # done along dim = 2, unlike the implementation for flash and torch attn + key = key.repeat_interleave(n_heads // kv_n_heads, dim=2) + value = value.repeat_interleave(n_heads // kv_n_heads, dim=2) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) attn_output = flash_attn_func( # type: ignore @@ -338,17 +396,21 @@ def triton_flash_attn_fn( return output, None, past_key_value -class MultiheadAttention(nn.Module): - """Multi-head self attention. +class GroupedQueryAttention(nn.Module): + """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA). - Using torch or triton attention implementation enables user to also use - additive bias. + and Multi-query attention (MQA). + + This allows the user to set a variable of number of kv_n_heads, rather than + just n_heads or 1, as in MHA and MQA. Using torch or triton attention + implementation enables user to also use additive bias. """ def __init__( self, d_model: int, n_heads: int, + kv_n_heads: int, attn_impl: str = 'triton', clip_qkv: Optional[float] = None, qk_ln: bool = False, @@ -367,6 +429,23 @@ def __init__( self.d_model = d_model self.n_heads = n_heads + self.kv_n_heads = kv_n_heads + + self.head_dim = d_model // n_heads + + if self.kv_n_heads <= 0: + raise ValueError('kv_n_heads should be greater than zero.') + + if self.kv_n_heads > self.n_heads: + raise ValueError( + 'The number of KV heads should be less than or equal to Q heads.' + ) + + if self.n_heads % self.kv_n_heads != 0: + raise ValueError( + 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.' 
+ ) + self.softmax_scale = softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) @@ -377,17 +456,21 @@ def __init__( fc_kwargs['device'] = device self.Wqkv = FC_CLASS_REGISTRY[fc_type]( self.d_model, - 3 * self.d_model, + self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs, ) # for param init fn; enables shape based init of fused layers - fuse_splits = (d_model, 2 * d_model) + fuse_splits = [ + i * self.head_dim + for i in range(1, self.n_heads + 2 * self.kv_n_heads) + ] self.Wqkv._fused = (0, fuse_splits) # type: ignore if self.qk_ln: norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] self.q_ln = norm_class(self.d_model, device=device) - self.k_ln = norm_class(self.d_model, device=device) + self.k_ln = norm_class(self.kv_n_heads * self.head_dim, + device=device) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn @@ -432,7 +515,14 @@ def forward( if self.clip_qkv: qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) - query, key, value = qkv.chunk(3, dim=2) + query, key, value = qkv.split( + [ + self.d_model, + self.kv_n_heads * self.head_dim, + self.kv_n_heads * self.head_dim, + ], + dim=2, + ) key_padding_mask = attention_mask @@ -447,6 +537,7 @@ def forward( key, value, self.n_heads, + self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, @@ -460,8 +551,8 @@ def forward( return self.out_proj(context), attn_weights, past_key_value -class MultiQueryAttention(nn.Module): - """Multi-Query self attention. +class MultiheadAttention(GroupedQueryAttention): + """Multi-head self attention. Using torch or triton attention implementation enables user to also use additive bias. @@ -481,113 +572,55 @@ def __init__( verbose: int = 0, device: Optional[str] = None, ): - super().__init__() - - self.attn_impl = attn_impl - self.clip_qkv = clip_qkv - self.qk_ln = qk_ln - - self.d_model = d_model - self.n_heads = n_heads - self.head_dim = d_model // n_heads - self.softmax_scale = softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = attn_pdrop - - fc_kwargs = {} - if fc_type != 'te': - fc_kwargs['device'] = device - # NOTE: if we ever want to make attn TensorParallel, I'm pretty sure we'll - # want to split Wqkv into Wq and Wkv where Wq can be TensorParallel but - # Wkv shouldn't be TensorParallel - # - vchiley - self.Wqkv = FC_CLASS_REGISTRY[fc_type]( - d_model, - d_model + 2 * self.head_dim, - **fc_kwargs, - ) - # for param init fn; enables shape based init of fused layers - fuse_splits = (d_model, d_model + self.head_dim) - self.Wqkv._fused = (0, fuse_splits) # type: ignore - - if self.qk_ln: - norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] - self.q_ln = norm_class(d_model, device=device) - self.k_ln = norm_class(self.head_dim, device=device) - - if self.attn_impl == 'flash': - self.attn_fn = flash_attn_fn - elif self.attn_impl == 'triton': - self.attn_fn = triton_flash_attn_fn - if verbose: - warnings.warn( - 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\ - 'it uses more memory. When training larger models this can trigger ' +\ - 'alloc retries which hurts performance. If encountered, we recommend ' +\ - 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.' - ) - elif self.attn_impl == 'torch': - self.attn_fn = scaled_multihead_dot_product_attention - if torch.cuda.is_available() and verbose: - warnings.warn( - 'Using `attn_impl: torch`. 
If your model does not use `alibi` or ' +\ - '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\ - 'we recommend using `attn_impl: triton`.' - ) - else: - raise ValueError(f'{attn_impl=} is an invalid setting.') + super().__init__( + d_model=d_model, + n_heads=n_heads, + kv_n_heads=n_heads, # for MHA, same # heads as kv groups + attn_impl=attn_impl, + clip_qkv=clip_qkv, + qk_ln=qk_ln, + softmax_scale=softmax_scale, + attn_pdrop=attn_pdrop, + norm_type=norm_type, + fc_type=fc_type, + verbose=verbose, + device=device) + + +class MultiQueryAttention(GroupedQueryAttention): + """Multi-Query self attention. - self.out_proj = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - self.d_model, - **fc_kwargs, - ) - self.out_proj._is_residual = True # type: ignore + Using torch or triton attention implementation enables user to also use + additive bias. + """ - def forward( + def __init__( self, - x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - attn_bias: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - is_causal: bool = True, - needs_weights: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ - torch.Tensor, torch.Tensor]]]: - qkv = self.Wqkv(x) - - if self.clip_qkv: - qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) - - query, key, value = qkv.split( - [self.d_model, self.head_dim, self.head_dim], dim=2) - - key_padding_mask = attention_mask - - if self.qk_ln: - # Applying layernorm to qk - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - - context, attn_weights, past_key_value = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - multiquery=True, - ) - - return self.out_proj(context), attn_weights, past_key_value + d_model: int, + n_heads: int, + attn_impl: str = 'triton', + clip_qkv: Optional[float] = None, + qk_ln: bool = False, + softmax_scale: Optional[float] = None, + attn_pdrop: float = 0.0, + norm_type: str = 'low_precision_layernorm', + fc_type: str = 'torch', + verbose: int = 0, + device: Optional[str] = None, + ): + super().__init__( + d_model=d_model, + n_heads=n_heads, + kv_n_heads=1, # for MQA, 1 head + attn_impl=attn_impl, + clip_qkv=clip_qkv, + qk_ln=qk_ln, + softmax_scale=softmax_scale, + attn_pdrop=attn_pdrop, + norm_type=norm_type, + fc_type=fc_type, + verbose=verbose, + device=device) def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, @@ -678,4 +711,5 @@ def build_alibi_bias( ATTN_CLASS_REGISTRY = { 'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention, + 'grouped_query_attention': GroupedQueryAttention } diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index cec14e5d2a..b5a3ff8d68 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -55,20 +55,24 @@ def __init__( assert isinstance(attn_config['attn_type'], str) attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] + # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs + args_to_exclude_in_attn_class = { + 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', + 'alibi_bias_max' + } + attn_config_subset_for_attn_class = { + k: v + for k, v in 
attn_config.items() + if k not in args_to_exclude_in_attn_class + } + self.norm_1 = norm_class(d_model, device=device) - self.attn = attn_class( - d_model=d_model, - n_heads=n_heads, - attn_impl=attn_config['attn_impl'], - clip_qkv=attn_config['clip_qkv'], - qk_ln=attn_config['qk_ln'], - softmax_scale=attn_config['softmax_scale'], - attn_pdrop=attn_config['attn_pdrop'], - norm_type=norm_type, - fc_type=fc_type, - verbose=verbose, - device=device, - ) + self.attn = attn_class(d_model=d_model, + n_heads=n_heads, + fc_type=fc_type, + verbose=verbose, + device=device, + **attn_config_subset_for_attn_class) self.norm_2 = None if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False): diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 693d19f898..08c02fa3b1 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -77,7 +77,7 @@ def __init__( emb_pdrop (float): The dropout probability for the embedding layer. learned_pos_emb (bool): Whether to use learned positional embeddings attn_config (Dict): A dictionary used to configure the model's attention module: - attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention + attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention attn_pdrop (float): The dropout probability for the attention layers. attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. @@ -94,6 +94,7 @@ def __init__( Defaults to ``False`` meaning any provided `sequence_id` will be ignored. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp init_device (str): The device to use for parameter initialization. @@ -102,7 +103,6 @@ def __init__( verbose (int): The verbosity level. 0 is silent. embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. norm_type (str): choose type of norm to use - multiquery_attention (bool): Whether to use multiquery attention implementation. use_cache (bool): Whether or not the model should return the last key/values attentions init_config (Dict): A dictionary used to configure the model initialization: init_config.name: The parameter initialization scheme to use. 
Options: 'default_', 'baseline_', diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index d029f4fe4d..145d4a5885 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -20,13 +20,15 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize('qk_ln', [True, False]) @pytest.mark.parametrize('alibi', [True, False]) -@pytest.mark.parametrize('multiquery', [True, False]) +@pytest.mark.parametrize( + 'attn_type', + ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) def test_attn_impl(attn_impl_0: str, attn_impl_1: str, clip_qkv: bool, qk_ln: bool, alibi: bool, - multiquery: bool, + attn_type: str, device: str = 'cuda'): """Compare all attn impl with each other. @@ -42,7 +44,7 @@ def test_attn_impl(attn_impl_0: str, cfg = om.create({ 'attn_impl': 'flash', 'd_model': 128, - 'n_heads': 2, + 'n_heads': 4, 'attn_pdrop': 0, 'clip_qkv': clip_qkv, 'qk_ln': qk_ln, @@ -50,16 +52,12 @@ def test_attn_impl(attn_impl_0: str, n, s, f = 2, 16, cfg.d_model + if attn_type == 'grouped_query_attention': + cfg.kv_n_heads = 2 + cfg.attn_impl = attn_impl_0 - if multiquery: - attn0 = attention.MultiQueryAttention(**cfg).to(device) - else: - attn0 = attention.MultiheadAttention(**cfg).to(device) - cfg.attn_impl = attn_impl_1 - if multiquery: - attn1 = attention.MultiQueryAttention(**cfg).to(device) - else: - attn1 = attention.MultiheadAttention(**cfg).to(device) + attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) attn1.load_state_dict(attn0.state_dict()) @@ -223,3 +221,101 @@ def gen_tca_mask(): assert allclose_helper(tmhsa.in_proj_weight.grad, mmhsa.Wqkv.weight.grad) assert allclose_helper(x0.grad, x1.grad) + + +@pytest.mark.gpu +@pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch']) +@pytest.mark.parametrize('n_heads', [32, 16, 8]) +@pytest.mark.parametrize('kv_n_heads', [8, 4, 2, 1]) +def test_grouped_attention_heads(attn_impl: str, + n_heads: int, + kv_n_heads: int, + device: str = 'cuda'): + """Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads.""" + from llmfoundry.models.layers import attention + + reproducibility.seed_all(17) + + cfg = om.create({ + 'attn_impl': attn_impl, + 'd_model': 256, + 'n_heads': n_heads, + 'attn_pdrop': 0, + 'clip_qkv': False, + 'qk_ln': False, + 'kv_n_heads': kv_n_heads + }) + + n, s, f = 2, 16, cfg.d_model + + mmhsa = attention.GroupedQueryAttention(**cfg).to(device) + + attention_mask = torch.ones(n, s).to(device).bool() + x0 = torch.randn(n, s, f).to(device) + x0.requires_grad = True + + with torch.autocast(x0.device.type): + y0, _, _ = mmhsa(x0, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + is_causal=True) + y0 *= attention_mask.unsqueeze(-1) + + loss0 = y0.sum() + + loss0.backward() + + +@pytest.mark.gpu +@pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch']) +def test_grouped_query_invalid_heads(attn_impl: str, device: str = 'cuda'): + """Check indivisble combinations of grouped_query_attention.""" + from llmfoundry.models.layers import attention + + reproducibility.seed_all(17) + + cfg = om.create({ + 'attn_impl': attn_impl, + 'd_model': 256, + 'n_heads': 16, + 'attn_pdrop': 0, + 'clip_qkv': False, + 'qk_ln': False, + 'kv_n_heads': 3 + }) + + expected_error = 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads' + + with 
pytest.raises(ValueError, match=expected_error): + _ = attention.GroupedQueryAttention(**cfg).to(device) + + cfg = om.create({ + 'attn_impl': attn_impl, + 'd_model': 256, + 'n_heads': 16, + 'attn_pdrop': 0, + 'clip_qkv': False, + 'qk_ln': False, + 'kv_n_heads': 17 + }) + + expected_error = 'The number of KV heads should be less than or equal to Q heads' + + with pytest.raises(ValueError, match=expected_error): + _ = attention.GroupedQueryAttention(**cfg).to(device) + + cfg = om.create({ + 'attn_impl': attn_impl, + 'd_model': 256, + 'n_heads': 16, + 'attn_pdrop': 0, + 'clip_qkv': False, + 'qk_ln': False, + 'kv_n_heads': 0 + }) + + expected_error = 'kv_n_heads should be greater than zero' + + with pytest.raises(ValueError, match=expected_error): + _ = attention.GroupedQueryAttention(**cfg).to(device) diff --git a/tests/test_model.py b/tests/test_model.py index 35d34e3626..f66b5382c3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -373,7 +373,7 @@ def test_loss_fn(): assert isinstance(test_cfg, DictConfig) test_cfg.device = 'cuda:0' - test_cfg.model.init_device = 'cuda:0' + test_cfg.model.init_device = 'cpu' test_cfg.model.init_config = { 'name': 'baseline_', 'init_std': 0.02, @@ -386,6 +386,10 @@ def test_loss_fn(): model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model, tokenizer) model_2 = copy.deepcopy(model_1) + + model_1.to(test_cfg.device) + model_2.to(test_cfg.device) + assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss) model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
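
Illustrative note, not part of the patch: a minimal sketch of the kv-head expansion that the torch, flash, and triton paths above all rely on. The function name toy_gqa and the tensor sizes are made up for illustration (causal masking, attention bias, and dropout are omitted); the real implementation is the scaled_multihead_dot_product_attention / flash_attn_fn / triton_flash_attn_fn changes in llmfoundry/models/layers/attention.py.

    import math
    import torch
    from einops import rearrange

    def toy_gqa(x: torch.Tensor, Wqkv: torch.nn.Linear, n_heads: int,
                kv_n_heads: int) -> torch.Tensor:
        d_model = x.size(-1)
        head_dim = d_model // n_heads
        # fused projection: d_model outputs for Q, plus 2 * kv_n_heads * head_dim for K and V
        qkv = Wqkv(x)
        query, key, value = qkv.split(
            [d_model, kv_n_heads * head_dim, kv_n_heads * head_dim], dim=2)
        q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
        k = rearrange(key, 'b s (h d) -> b h s d', h=kv_n_heads)
        v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
        if kv_n_heads < n_heads:
            # each group of n_heads // kv_n_heads query heads shares one kv head,
            # so repeat every kv head group-size times and reuse plain MHA math
            k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
            v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
        return rearrange(attn @ v, 'b h s d -> b s (h d)')

    # kv_n_heads == n_heads recovers MHA; kv_n_heads == 1 recovers MQA.
    n_heads, kv_n_heads, d_model = 4, 2, 128
    Wqkv = torch.nn.Linear(d_model, d_model + 2 * kv_n_heads * (d_model // n_heads))
    out = toy_gqa(torch.randn(2, 16, d_model), Wqkv, n_heads, kv_n_heads)  # (2, 16, 128)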
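
Usage sketch mirroring the new tests above: constructing GroupedQueryAttention directly. The sizes and argument values are illustrative placeholders (the tests run the flash/triton paths on GPU; the torch path shown here is assumed to run on CPU as well).

    import torch
    from llmfoundry.models.layers import attention

    attn = attention.GroupedQueryAttention(
        d_model=256,
        n_heads=16,
        kv_n_heads=4,       # must divide n_heads; 16 reduces to MHA, 1 to MQA
        attn_impl='torch',  # 'flash' / 'triton' need the corresponding kernels and a GPU
    )
    x = torch.randn(2, 16, 256)
    attention_mask = torch.ones(2, 16, dtype=torch.bool)
    y, _, _ = attn(x, attention_mask=attention_mask, is_causal=True)  # y: (2, 16, 256)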