[Stateless Llama] StreamingLLM + Add KVCache for prefill stage + interactive chat mode in llm_runner. (#299)

This PR introduces StreamingLLM and KV-cache support at the initialization/prefill stage; together these let us generate an effectively unbounded number of tokens with controlled memory growth.

This PR also introduces:
1. A set capability for GlobalScalars.
2. Inheritance of exports/globals for CompiledModule subclasses.
3. READMEs for llm_runner and stateless_llama.
4. e2e test refactoring.
raikonenfnu authored Jan 5, 2024
1 parent 18e8a41 commit 432fa0d
Showing 14 changed files with 736 additions and 113 deletions.
8 changes: 7 additions & 1 deletion .gitignore
@@ -21,4 +21,10 @@ _python_build/
dist/
wheelhouse
*.egg-info
*.whl
*.whl

#Model artifacts
*.pt
*.safetensors
*.gguf
*.vmfb
12 changes: 12 additions & 0 deletions python/shark_turbine/aot/compiled_module.py
@@ -334,6 +334,7 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None):
continue
del_attr_keys.add(key)
info.def_attribute(key, value)

for key in del_attr_keys:
del dct[key]

@@ -343,6 +344,17 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None):
if key not in dct:
dct[key] = _blackhole_instance_attribute

# Inherit methods, globals, and exports from the parent class.
# Use case: building a child class of StatelessLlama.
for base in bases:
if base is CompiledModule:
continue
base_exports = _all_compiled_module_class_infos[base].all_exports
for export_name in base_exports:
if export_name in info.all_exports:
continue
info.all_exports[export_name] = base_exports[export_name]

# Finish construction.
new_class = type.__new__(mcls, name, bases, dct)
_all_compiled_module_class_infos[new_class] = info
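For context, here is a standalone plain-Python sketch of the inheritance pattern the hunk above implements (illustrative only, not the Turbine code; `ExportMeta` and the registry name are made up): a metaclass records each class's exports and copies the parent's exports into a subclass unless the subclass redefines them.

```
# Illustrative only: a metaclass that merges "exports" from base classes,
# mirroring the CompiledModule behavior added above.
_class_exports = {}


class ExportMeta(type):
    def __new__(mcls, name, bases, dct):
        # Treat every public attribute declared on the class as an export.
        exports = {k: v for k, v in dct.items() if not k.startswith("_")}
        # Pull in exports from base classes, keeping the child's overrides.
        for base in bases:
            for key, value in _class_exports.get(base, {}).items():
                exports.setdefault(key, value)
        cls = super().__new__(mcls, name, bases, dct)
        _class_exports[cls] = exports
        return cls


class Base(metaclass=ExportMeta):
    def run(self):
        return "base.run"


class Child(Base):  # inherits Base's "run" export and adds its own
    def extra(self):
        return "child.extra"


assert set(_class_exports[Child]) == {"run", "extra"}
```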
4 changes: 4 additions & 0 deletions python/shark_turbine/aot/support/procedural/globals.py
@@ -241,6 +241,10 @@ def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]):
with proc_trace.loc, proc_trace.ip:
util_d.GlobalStoreOp(ir_values[0], self.symbol_name)

def set(self, other):
t = current_ir_trace()
self.resolve_assignment(t, super().set(other).ir_values)

def __repr__(self):
return (
f"<IrGlobalScalar {self.export_name} = {self.symbol_name}:{self.ir_type}>"
20 changes: 20 additions & 0 deletions python/shark_turbine/aot/support/procedural/primitives.py
@@ -68,6 +68,26 @@ class IrScalar(Intrinsic):
def __init__(self, ir_type: IrType):
self.ir_type = ir_type

def set(self, other):
t = current_ir_trace()
with t.ip, t.loc:
# Type check and promotion.
# TODO: Add a more comprehensive type promotion hierarchy.
lhs = self.ir_value
rhs = None
if isinstance(other, IrScalar):
# Assumes that when both are Values, they have the same type.
rhs = other.ir_value
elif isinstance(other, (int, bool)) and _is_integer_like_type(self.ir_type):
rhs = arith_d.ConstantOp(lhs.type, other).result
elif isinstance(other, (float)) and _is_float_type(self.ir_type):
rhs = arith_d.ConstantOp(lhs.type, other).result
if rhs is None or lhs.type != rhs.type:
raise ValueError(
f"Cannot handle src type of {self.ir_type} to dst python type of {type(other)}."
)
return IrImmediateScalar(rhs)

def __add__(self, other):
t = current_ir_trace()
with t.ip, t.loc:
39 changes: 39 additions & 0 deletions python/turbine_models/custom_models/README.md
@@ -0,0 +1,39 @@
# Instructions

Clone and install SHARK-Turbine
```
git clone https://github.com/nod-ai/SHARK-Turbine.git
cd SHARK-Turbine
python -m venv turbine_venv && source turbine_venv/bin/activate
pip install --upgrade -r requirements.txt
pip install --upgrade -e .[torch-cpu-nightly,testing]
pip install --upgrade -r turbine-models-requirements.txt
```

## Compiling LLMs
Note: make sure to replace `your_token` with your actual hf_auth_token in all of the commands below.

Now, you can generate the quantized weight file with
```
python python/turbine_models/gen_external_params/gen_external_params.py --hf_auth_token=your_token
```
The model weights will then be saved in the current directory as `Llama_2_7b_chat_hf_f16_int4.safetensors`.
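Optionally, you can sanity-check the generated weight file with the `safetensors` package (an illustrative snippet, not a required step; the file name assumes the default output above):

```
# List a few tensor names and shapes from the generated weight file
# without loading the full tensors into memory.
from safetensors import safe_open

with safe_open("Llama_2_7b_chat_hf_f16_int4.safetensors", framework="pt") as f:
    for i, name in enumerate(f.keys()):
        if i >= 5:
            break
        print(name, f.get_slice(name).get_shape())
```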

To compile Llama to a vmfb:
```
python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16"
```
By default the vmfb will be saved as `Llama_2_7b_chat_hf.vmfb`.

## Running LLMs
There are two ways of running LLMs:

1) A single run with a predefined prompt to validate correctness.
```
python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token
```
2) Interactive CLI chat mode (just add the `--chat_mode` flag).
```
python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token --chat_mode
```
Empty file.
@@ -0,0 +1,32 @@
# StreamingLLM

StreamingLLM is based on the paper *"Efficient Streaming Language Models with Attention Sinks"* by Xiao et al. from the MIT Han Lab. Here are the original [[paper](http://arxiv.org/abs/2309.17453)] and [[code](https://github.com/mit-han-lab/streaming-llm)].

The modify_llama.py code is heavily inspired by the modify_llama.py in the original repo, but tweaked to work with ToM HuggingFace and to be compilable through Turbine.

The work introduces attention sinks: in short, attention over a small fixed set of initial tokens is combined with sliding-window attention over the most recent tokens. This is beneficial because it lets us:

1) Generate effectively unlimited context.
2) Keep memory under a fixed threshold (controlled by window_length); see the sketch below.
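Below is a minimal, self-contained sketch of the eviction idea in plain PyTorch (not the Turbine implementation; `sink_size` and `window_size` are illustrative parameters): keep a few initial "sink" positions plus the most recent positions, and drop everything in between. Positions shift after eviction, which is why modify_llama.py applies RoPE to keys based on their position in the cache rather than their absolute token index.

```
import torch


def evict_kv(cache, sink_size=4, window_size=256):
    # cache: [batch, heads, seq_len, head_dim]
    seq_len = cache.shape[2]
    if seq_len <= window_size:
        return cache
    sink = cache[:, :, :sink_size, :]
    recent = cache[:, :, seq_len - (window_size - sink_size):, :]
    return torch.cat([sink, recent], dim=2)


kv = torch.randn(1, 8, 600, 64)
print(evict_kv(kv).shape)  # torch.Size([1, 8, 256, 64])
```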


## Compiling LLMs with StreamingLLM

Just add the extra `--streaming_llm` flag when you call stateless_llama to generate your vmfb. For example:
```
python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16" --streaming_llm
```

By default the vmfb will still be saved as `Llama_2_7b_chat_hf.vmfb`.

## Running LLMs with StreamingLLM

As with compiling, just add the extra `--streaming_llm` flag when you call llm_runner.py. For example:
```
python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token --chat_mode --streaming_llm=true
```

## Future Work
- [ ] Make the window size configurable from Python. Everything is in place, but we need to initialize it with a default value, which will only be possible once we let `_create_initial_value` take an initial value from a GlobalAttribute somewhere [here](https://github.com/nod-ai/SHARK-Turbine/blob/18e8a4100b61adfd9425dd32f780dc5f90017813/python/shark_turbine/aot/support/ir_utils.py#L284-L316).
- [ ] Get flow.move to allow the sliding window and the source of the data to overlap. (Currently we can only evict once the cache is at least 2x the window size.) For example, by default our StreamingLLM window_size is 256, so we evict at ~600 tokens (slightly more than 2x for safety).
- [ ] Introduce re-rotation of RoPE as seen [here](https://github.com/huggingface/transformers/blob/c2d283a64a7f33547952e3eb0fa6533fc375bcdd/src/transformers/cache_utils.py#L213-L218) to remove the invasive modification of the LlamaAttention module for StreamingLLM.
Empty file.
@@ -0,0 +1,171 @@
import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.utils.checkpoint

import torch.nn.functional as F

from transformers.models.llama.modeling_llama import (
LlamaAttention,
rotate_half,
apply_rotary_pos_emb,
repeat_kv,
)
import types

__all__ = ["enable_llama_pos_shift_attention"]


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed


def llama_pos_shift_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [
F.linear(hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)

key_states = [
F.linear(hidden_states, key_slices[i])
for i in range(self.config.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)

value_states = [
F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
### Shift Pos: query pos is min(cache_size, idx)
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
###

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

### Shift Pos: key pos is the pos in cache
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
###

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
self.head_dim
)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(
self.hidden_size // self.config.pretraining_tp, dim=2
)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.config.pretraining_tp, dim=1
)
attn_output = sum(
[
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
]
)
else:
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


def enable_llama_pos_shift_attention(model):
for name, module in reversed(model._modules.items()):
if len(list(module.children())) > 0:
enable_llama_pos_shift_attention(
module,
)

if isinstance(module, LlamaAttention):
model._modules[name].forward = types.MethodType(
llama_pos_shift_attention_forward, model._modules[name]
)
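A hypothetical usage sketch for the patch above (the import path is assumed and should be adjusted to where this file lives in the repo; the model id and token are placeholders):

```
from transformers import AutoModelForCausalLM

# Assumed import path; adjust to the actual location of this module.
from modify_llama import enable_llama_pos_shift_attention

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    token="your_token",  # replace with a real HF auth token
)
# Monkey-patch every LlamaAttention so queries use a capped position id and
# keys are rotated according to their position in the cache.
enable_llama_pos_shift_attention(model)
```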
