Update stream llm to get correct outputs and re-enable rerotated-attention test. #656

Merged
3 commits merged on Apr 25, 2024
@@ -37,38 +37,9 @@ def llama_pos_shift_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()

-    if self.config.pretraining_tp > 1:
-        key_value_slicing = (
-            self.num_key_value_heads * self.head_dim
-        ) // self.config.pretraining_tp
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-        )
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [
-            F.linear(hidden_states, query_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [
-            F.linear(hidden_states, key_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [
-            F.linear(hidden_states, value_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        value_states = torch.cat(value_states, dim=-1)
-
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)

     query_states = query_states.view(
         bsz, q_len, self.num_heads, self.head_dim
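An aside on the removed `pretraining_tp` branch above: splitting a projection weight into `pretraining_tp` slices, applying each slice with `F.linear`, and concatenating the per-slice outputs is numerically equivalent (up to floating-point rounding) to applying the full projection once, which is why the streaming path can call `q_proj`/`k_proj`/`v_proj` directly. The toy check below is a sketch with made-up sizes, not part of the PR, to illustrate that equivalence.

# Toy check (illustrative sizes only, not part of the PR): sliced projections
# concatenated along the last dim match a single full projection.
import torch
import torch.nn.functional as F

hidden_size, num_heads, head_dim, pretraining_tp = 32, 4, 8, 2
weight = torch.randn(num_heads * head_dim, hidden_size)
hidden_states = torch.randn(1, 5, hidden_size)

slices = weight.split((num_heads * head_dim) // pretraining_tp, dim=0)
sliced = torch.cat(
    [F.linear(hidden_states, slices[i]) for i in range(pretraining_tp)], dim=-1
)
full = F.linear(hidden_states, weight)

# The two results agree up to floating-point rounding.
assert torch.allclose(sliced, full, atol=1e-5)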
@@ -103,9 +74,9 @@ def llama_pos_shift_attention_forward(
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
-        self.head_dim
+    softmax_scale = 1.0 / math.sqrt(self.head_dim)
+    attn_weights = (
+        torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale
     )

     if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -114,6 +85,23 @@
f" {attn_weights.size()}"
)

# For causal mode, we use to get input mask, but now causal mode does not expect a mask
# and we need to generate the causal mask ourselves.
current_is_causal = False
if self.is_causal and attention_mask is None and q_len > 1:
current_is_causal = True
if current_is_causal and attention_mask is None:
bool_attention_mask = torch.ones(
[query_states.shape[-2], key_states.shape[-2]],
device=query_states.device,
dtype=torch.bool,
).tril()
additive_attention_mask = torch.zeros_like(
bool_attention_mask, dtype=attn_weights.dtype
).masked_fill(bool_attention_mask.logical_not(), -10000)
attn_weights = attn_weights + additive_attention_mask

# Legacy support to take in mask for non-causal mode.
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
@@ -132,30 +120,10 @@
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(
self.hidden_size // self.config.pretraining_tp, dim=2
)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.config.pretraining_tp, dim=1
)
attn_output = sum(
[
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
]
)
else:
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value


def enable_llama_pos_shift_attention(model):
Expand Down
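The core of the change is the hand-built causal mask: because the patched attention no longer receives a mask from the caller in causal mode, it builds a lower-triangular boolean mask and converts it into an additive bias, alongside the `softmax_scale` factoring of the scaled dot-product scores. The sketch below reproduces that construction in isolation with dummy shapes and tensors; it is illustrative only and not part of the PR.

# Standalone sketch (dummy shapes, not part of the PR) of the additive causal
# mask construction and softmax_scale used in the hunk above.
import math
import torch

bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 2, 4, 4, 8
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# Scaled dot-product scores, matching the softmax_scale refactor in the diff.
softmax_scale = 1.0 / math.sqrt(head_dim)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale

# Lower-triangular boolean mask: position i may attend to positions <= i.
bool_attention_mask = torch.ones([q_len, kv_seq_len], dtype=torch.bool).tril()

# Additive form: 0 where attention is allowed, -10000 where it is not.
additive_attention_mask = torch.zeros(
    q_len, kv_seq_len, dtype=attn_weights.dtype
).masked_fill(bool_attention_mask.logical_not(), -10000)

attn_weights = attn_weights + additive_attention_mask
attn_probs = torch.softmax(attn_weights, dim=-1)

# After softmax, masked (future) positions carry essentially zero probability.
print(attn_probs[0, 0])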
1 change: 0 additions & 1 deletion models/turbine_models/tests/stateless_llama_test.py
@@ -193,7 +193,6 @@ def test_streaming_vmfb_comparison(self):

     # See: https://github.com/nod-ai/SHARK-Turbine/issues/560
     # Developed issues related to the pytorch 2.3 upgrade.
-    @unittest.expectedFailure
     def test_rerotated_torch_comparison(self):
         torch_str = llm_runner.run_torch_llm(
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
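With `@unittest.expectedFailure` removed, `test_rerotated_torch_comparison` runs as a regular test again. A minimal sketch of invoking just this test programmatically, assuming pytest is installed and the repository layout matches the path shown above:

# Hypothetical local run of only the re-enabled test (assumes pytest is
# installed and the model weights used by the test are reachable).
import pytest

pytest.main(
    [
        "models/turbine_models/tests/stateless_llama_test.py",
        "-k",
        "rerotated_torch_comparison",
    ]
)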