Update stream llm to get correct outputs and re-enable rerotated-attention test. (#656)

During the pytorch/HF update, there appears to have been a change in how the causal
mask is handled. The attention forward function used to receive a `causal_mask`
through the `attention_mask` argument when `is_causal` is on; now we need to
construct our own mask when `is_causal` is true. This was causing numerical issues
in this test, as well as qualitative regressions on Llama2.

This PR introduces construction of the causal mask, and it also removes unnecessary
tensor parallel config checks, which simplifies the code quite a bit.
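
For reference, here is a minimal standalone sketch of the additive causal-mask construction described above. It uses plain PyTorch with illustrative shapes and the same -10000 fill value as the diff below; it is not the exact attention-module code from this PR.

```python
import torch

# Illustrative sizes; the real code derives these from the query/key tensors.
q_len, kv_len = 4, 4
attn_weights = torch.zeros(1, 1, q_len, kv_len)  # placeholder pre-softmax scores

# Lower-triangular boolean mask: position i may only attend to positions <= i.
bool_mask = torch.ones(q_len, kv_len, dtype=torch.bool).tril()

# Additive mask: 0 where attention is allowed, -10000 where it is not.
additive_mask = torch.zeros(q_len, kv_len).masked_fill(bool_mask.logical_not(), -10000)

# Adding the mask before softmax suppresses attention to future positions.
attn_weights = attn_weights + additive_mask
probs = torch.softmax(attn_weights, dim=-1)
print(probs[0, 0])  # each row i only has non-zero weight on columns 0..i
```
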
raikonenfnu authored Apr 25, 2024
1 parent 7877444 commit 4a01c40
Showing 2 changed files with 25 additions and 58 deletions.
@@ -37,38 +37,9 @@ def llama_pos_shift_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()

-    if self.config.pretraining_tp > 1:
-        key_value_slicing = (
-            self.num_key_value_heads * self.head_dim
-        ) // self.config.pretraining_tp
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-        )
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [
-            F.linear(hidden_states, query_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [
-            F.linear(hidden_states, key_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [
-            F.linear(hidden_states, value_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        value_states = torch.cat(value_states, dim=-1)
-
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)

     query_states = query_states.view(
         bsz, q_len, self.num_heads, self.head_dim
@@ -103,9 +74,9 @@ def llama_pos_shift_attention_forward(
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
-        self.head_dim
+    softmax_scale = 1.0 / math.sqrt(self.head_dim)
+    attn_weights = (
+        torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale
     )

     if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -114,6 +85,23 @@ def llama_pos_shift_attention_forward(
             f" {attn_weights.size()}"
         )

+    # For causal mode, we used to get an input mask, but now causal mode does not
+    # expect a mask and we need to generate the causal mask ourselves.
+    current_is_causal = False
+    if self.is_causal and attention_mask is None and q_len > 1:
+        current_is_causal = True
+    if current_is_causal and attention_mask is None:
+        bool_attention_mask = torch.ones(
+            [query_states.shape[-2], key_states.shape[-2]],
+            device=query_states.device,
+            dtype=torch.bool,
+        ).tril()
+        additive_attention_mask = torch.zeros_like(
+            bool_attention_mask, dtype=attn_weights.dtype
+        ).masked_fill(bool_attention_mask.logical_not(), -10000)
+        attn_weights = attn_weights + additive_attention_mask
+
+    # Legacy support to take in mask for non-causal mode.
     if attention_mask is not None:
         if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
             raise ValueError(
Expand All @@ -132,30 +120,10 @@ def llama_pos_shift_attention_forward(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(
self.hidden_size // self.config.pretraining_tp, dim=2
)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.config.pretraining_tp, dim=1
)
attn_output = sum(
[
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
]
)
else:
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value


def enable_llama_pos_shift_attention(model):
models/turbine_models/tests/stateless_llama_test.py (1 change: 0 additions & 1 deletion)
@@ -193,7 +193,6 @@ def test_streaming_vmfb_comparison(self):

     # See: https://github.com/nod-ai/SHARK-Turbine/issues/560
     # Developed issues related to the pytorch 2.3 upgrade.
-    @unittest.expectedFailure
     def test_rerotated_torch_comparison(self):
         torch_str = llm_runner.run_torch_llm(
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
