Skip to content

Commit

Permalink
Falcon graph compilation error fix for when bs>1
Browse files Browse the repository at this point in the history
  • Loading branch information
schoi-habana committed Dec 20, 2023
1 parent 21238af commit 47893cc
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion optimum/habana/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,25 @@ def gaudi_falcon_rotary_embedding_forward(self, query, key, seq_len, position_id
"""
cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)

query_expansion_factor = int(query.shape[0] / cos.shape[0])
if query_expansion_factor > 1:
query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
else:
query_cos, query_sin = cos, sin

key_expansion_factor = int(key.shape[0] / cos.shape[0])
if key_expansion_factor > 1:
if key_expansion_factor != query_expansion_factor:
key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
else:
key_cos, key_sin = query_cos, query_sin
else:
key_cos, key_sin = cos, sin

if FusedRoPE:
return FusedRoPE.apply(query, cos, sin, 0), FusedRoPE.apply(key, cos, sin, 0)
return FusedRoPE.apply(query, query_cos, query_sin, 0), FusedRoPE.apply(key, key_cos, key_sin, 0)
else:
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)

Expand Down

0 comments on commit 47893cc

Please sign in to comment.