
Commit

Add workflow and apply Ruff formatting
wangmengdi committed May 16, 2024
1 parent 64910fa commit 814d1c6
Showing 27 changed files with 185 additions and 502 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,8 @@
name: Ruff
on: [push, pull_request]
jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: chartboost/ruff-action@v1
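Note (illustration only, not part of this diff): the workflow above lints the repository with Ruff on every push and pull request. A comparable check can be run locally before pushing, assuming `ruff` is installed in the active environment (e.g. via `pip install ruff`); a minimal Python wrapper would look like this:

# Hypothetical local helper, not part of this commit: runs the same kind of Ruff lint the CI job performs.
import subprocess
import sys

# "ruff check ." lints the whole repository; a non-zero exit code means violations were found.
result = subprocess.run(["ruff", "check", "."])
sys.exit(result.returncode)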
20 changes: 5 additions & 15 deletions osc_llm/__main__.py
@@ -71,9 +71,7 @@ def download_huggingface_model(
try:
result = safetensors_load(safetensor_path)
except SafetensorError as e:
raise RuntimeError(
f"{safetensor_path} is likely corrupted. Please try to re-download it."
) from e
raise RuntimeError(f"{safetensor_path} is likely corrupted. Please try to re-download it.") from e
msg.info(f"{safetensor_path} --> {bin_path}")
torch.save(result, bin_path)
os.remove(safetensor_path)
@@ -91,9 +89,7 @@ def get_hf_model_helper(checkpoint_dir: str) -> HFModelHelper:
text=f"Supported architectures are: {allowed_architectures}",
exits=1,
)
model_helper: HFModelHelper = registry.model_helpers.get(architecture)(
checkpoint_dir
)
model_helper: HFModelHelper = registry.model_helpers.get(architecture)(checkpoint_dir)
return model_helper


@@ -124,9 +120,7 @@ def quantize_int8(checkpoint_dir: str, save_dir: str):
if not save_dir.exists():
save_dir.mkdir(parents=True)
tokenizer = Tokenizer(checkpoint_dir=checkpoint_dir)
model, config = build_from_checkpoint(
checkpoint_dir=checkpoint_dir, return_config=True
)
model, config = build_from_checkpoint(checkpoint_dir=checkpoint_dir, return_config=True)
quantizer = Int8Quantizer()
model = quantizer.quantize(model)
config = config.merge(quantizer.quantizer_config)
@@ -160,13 +154,9 @@ def quantize_int4(
if not Path(save_dir).exists():
Path(save_dir).mkdir(parents=True)
tokenizer = Tokenizer(checkpoint_dir=checkpoint_dir)
model, config = build_from_checkpoint(
checkpoint_dir=checkpoint_dir, return_config=True
)
model, config = build_from_checkpoint(checkpoint_dir=checkpoint_dir, return_config=True)
model.to(device)
quantizer = WeightOnlyInt4Quantizer(
groupsize=groupsize, inner_k_tiles=k, padding_allowed=padding
)
quantizer = WeightOnlyInt4Quantizer(groupsize=groupsize, inner_k_tiles=k, padding_allowed=padding)
model = quantizer.quantize(model)
config = config.merge(quantizer.quantizer_config)
torch.save(model.state_dict(), Path(save_dir) / "osc_model.pth")
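Note (illustration only, not part of this diff): pieced together from the calls visible in the quantize_int8 hunk above, the int8 conversion amounts to roughly the following sketch. The import location, the checkpoint paths, and the final save step (mirrored from the int4 hunk) are assumptions for illustration.

# Sketch only; import paths and directories are assumed, the calls themselves appear in the diff above.
from pathlib import Path
import torch
from osc_llm import Int8Quantizer, Tokenizer, build_from_checkpoint  # assumed import location

checkpoint_dir = Path("checkpoints/SomeModel")   # hypothetical input checkpoint
save_dir = Path("checkpoints/SomeModel-int8")    # hypothetical output directory
save_dir.mkdir(parents=True, exist_ok=True)

tokenizer = Tokenizer(checkpoint_dir=checkpoint_dir)  # loaded by the CLI as well; its use lies outside the shown hunk
model, config = build_from_checkpoint(checkpoint_dir=checkpoint_dir, return_config=True)
quantizer = Int8Quantizer()
model = quantizer.quantize(model)                    # swap weights for int8 equivalents
config = config.merge(quantizer.quantizer_config)    # record the quantizer settings in the config
torch.save(model.state_dict(), save_dir / "osc_model.pth")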
39 changes: 9 additions & 30 deletions osc_llm/architectures/transformer_decoder.py
@@ -107,9 +107,7 @@ def kv_caches(self) -> List[KVCache]:

@kv_caches.setter
def kv_caches(self, value: List[KVCache]):
assert len(value) == len(
self.blocks
), "Number of kv_caches must match number of blocks"
assert len(value) == len(self.blocks), "Number of kv_caches must match number of blocks"
for block, kv_cache in zip(self.blocks, value):
block.attention.kv_cache = kv_cache

@@ -139,18 +137,14 @@ def setup_kv_cache(
dtype: Optional[torch.dtype] = None,
):
if kv_cache:
assert isinstance(
kv_cache, KVCache
), "kv_cache must be an instance of KVCache"
assert isinstance(kv_cache, KVCache), "kv_cache must be an instance of KVCache"
else:
kv_cache = StaticKVCache()
self.kv_caches = [deepcopy(kv_cache) for _ in range(self.n_blocks)]
if not max_length:
max_length = self.block_size
else:
assert (
max_length <= self.block_size
), "max_length must be less than or equal to block_size"
assert max_length <= self.block_size, "max_length must be less than or equal to block_size"

for block in self.blocks:
block.attention.setup_kv_cache(
@@ -161,16 +155,10 @@ def setup_kv_cache(
)

self.mask_cache = (
torch.tril(
torch.ones((max_length, max_length), device=device, dtype=torch.bool)
)
.unsqueeze(0)
.unsqueeze(0)
torch.tril(torch.ones((max_length, max_length), device=device, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)
)

def setup_rope_cache(
self, max_length: int, device: Optional[torch.device] = None
) -> None:
def setup_rope_cache(self, max_length: int, device: Optional[torch.device] = None) -> None:
head_size = self.blocks[0].attention.head_size
cos, sin = build_rope_cache(
seq_len=max_length,
@@ -182,9 +170,7 @@ def setup_rope_cache(
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)

def forward(
self, input_ids: torch.Tensor, input_pos: Optional[torch.Tensor] = None
):
def forward(self, input_ids: torch.Tensor, input_pos: Optional[torch.Tensor] = None):
"""Forward pass of the TransformerDecoder.
Args:
@@ -195,9 +181,7 @@ def forward(
B, L = input_ids.size()

if self.max_length < L:
raise ValueError(
f"Cannot forward sequence of length {L}, max seq length is only {self.max_seq_length}."
)
raise ValueError(f"Cannot forward sequence of length {L}, max seq length is only {self.max_seq_length}.")

if input_pos is not None:
# use rope cache
@@ -224,9 +208,7 @@ def forward(

return x

def load_state_dict(
self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = True
):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = True):
# Ensure that after building the model with torch.device('meta'), model.to('cuda:xxx') still works; otherwise it errors because cos and sin are meta tensors
self.setup_rope_cache(max_length=self.max_length)
return super().load_state_dict(state_dict, strict, assign)
@@ -247,10 +229,7 @@ def model_size(self, include_embeddings: bool = True) -> int:
if n == "embedding" and not include_embeddings:
continue
model_size += sum(
[
p.numel() * p.dtype.itemsize
for p in itertools.chain(children.parameters(), children.buffers())
]
[p.numel() * p.dtype.itemsize for p in itertools.chain(children.parameters(), children.buffers())]
)
return model_size

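Note (illustration only, not part of this diff): the mask_cache expression reformatted above builds a causal attention mask. A standalone sketch with a small max_length shows what it produces:

# Standalone illustration of the mask_cache expression above; the small size is only for readability.
import torch

max_length = 4
mask = torch.tril(torch.ones((max_length, max_length), dtype=torch.bool)).unsqueeze(0).unsqueeze(0)
print(mask.shape)   # torch.Size([1, 1, 4, 4]); the two leading dims broadcast over (batch, heads)
print(mask[0, 0])
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])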
20 changes: 6 additions & 14 deletions osc_llm/chat.py
@@ -59,9 +59,9 @@ def main(
if not hasattr(engine, "decode_model"):
model_size = engine.model.model_size(include_embeddings=False)
else:
model_size = engine.decode_model.model_size(
model_size = engine.decode_model.model_size(include_embeddings=False) + engine.prefill_model.model_size(
include_embeddings=False
) + engine.prefill_model.model_size(include_embeddings=False)
)

if compile:
t = time.perf_counter()
@@ -74,9 +74,7 @@ def main(
engine.fabric.print("\n")

messages = []
pre_ids_len = (
0  # Cache the earlier conversation history across chat turns so the prefill stage does not rebuild the KV cache
)
pre_ids_len = 0  # Cache the earlier conversation history across chat turns so the prefill stage does not rebuild the KV cache
while True:
content = input("User (empty to exit): ")
if content == "":
@@ -88,9 +86,7 @@ def main(
input_pos = torch.arange(pre_ids_len, len(input_ids))
input_ids = input_ids[pre_ids_len:]

stream = engine.run(
input_ids=input_ids, stop_ids=tokenizer.stop_ids, input_pos=input_pos
)
stream = engine.run(input_ids=input_ids, stop_ids=tokenizer.stop_ids, input_pos=input_pos)
generated_text = ""
engine.fabric.print("Assistant: ")
time0 = time.perf_counter()
@@ -114,10 +110,6 @@ def main(
engine.fabric.print(
f"Generated {num_new_tokens} tokens in {t:.02f} seconds, {(num_new_tokens / t):.2f} tokens/second"
)
engine.fabric.print(
f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
)
engine.fabric.print(
f"Bandwidth Achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
)
engine.fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
engine.fabric.print(f"Bandwidth Achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
engine.fabric.print("\n")
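Note (illustration only, not part of this diff): the bandwidth line above multiplies the weight footprint in bytes by the decode throughput, on the rough assumption that each generated token reads every weight once. With made-up numbers the arithmetic is:

# Hypothetical numbers, only to illustrate the bandwidth formula printed above.
model_size = 14e9   # bytes of non-embedding weights, e.g. roughly a 7B-parameter model in bf16
tokens_sec = 20.0   # measured decode throughput in tokens/second
print(f"Bandwidth Achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")  # -> 280.00 GB/s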
8 changes: 2 additions & 6 deletions osc_llm/chat_templates/base.py
@@ -34,9 +34,7 @@ class ChatTemplate:
generate_prompt: str = ""

@classmethod
def apply_messages(
cls, messages: List[Message], add_generate_prompt: bool = True
) -> str:
def apply_messages(cls, messages: List[Message], add_generate_prompt: bool = True) -> str:
raise NotImplementedError

@classmethod
@@ -75,7 +73,5 @@ def from_checkpoint(cls, checkpoint_dir: str) -> "ChatTemplate":
if config_path.exists():
config = Config().from_disk(config_path)
if "chat_template" in config:
return registry.chat_templates.get(
config["chat_template"]["@chat_templates"]
)
return registry.chat_templates.get(config["chat_template"]["@chat_templates"])
return cls.from_name(checkpoint_dir.stem)
4 changes: 1 addition & 3 deletions osc_llm/chat_templates/chatglm.py
@@ -10,9 +10,7 @@ class ChatGLM3ChatTemplate(ChatTemplate):
stop_texts: List[str] = ["<|observation|>", "<|user|>"]

@classmethod
def apply_messages(
cls, messages: List[Message], add_generate_prompt: bool = True
) -> str:
def apply_messages(cls, messages: List[Message], add_generate_prompt: bool = True) -> str:
prompt: str = ""
for message in messages:
if message.role == "user":
4 changes: 1 addition & 3 deletions osc_llm/chat_templates/chatml.py
@@ -12,9 +12,7 @@ class ChatMLChatTemplate(ChatTemplate):
generate_prompt: str = "<|im_start|>assistant\n"

@classmethod
def apply_messages(
cls, messages: List[Message], add_generate_prompt: bool = True
) -> str:
def apply_messages(cls, messages: List[Message], add_generate_prompt: bool = True) -> str:
prompt = ""
for message in messages:
if message.role == "user":
8 changes: 2 additions & 6 deletions osc_llm/chat_templates/llama.py
@@ -11,9 +11,7 @@ class Llama3ChatTemplate(ChatTemplate):
generate_prompt: str = "<|start_header_id|>assistant<|end_header_id|>\n\n"

@classmethod
def apply_messages(
cls, messages: List[Message], add_generate_prompt: bool = False
) -> str:
def apply_messages(cls, messages: List[Message], add_generate_prompt: bool = False) -> str:
assert messages[-1].role == "user", "Last message must be user"
prompt = "<|begin_of_text|>"
for message in messages:
@@ -42,9 +40,7 @@ class Llama2ChatTemplate(ChatTemplate):
generate_prompt: str = ""

@classmethod
def apply_messages(
cls, messages: List[Message], add_generate_prompt: bool = True
) -> str:
def apply_messages(cls, messages: List[Message], add_generate_prompt: bool = True) -> str:
if messages[0].role == "system":
assert len(messages) >= 2, "must have a user input"
assert messages[1].role == "user", "must have a user input"
12 changes: 3 additions & 9 deletions osc_llm/engines/base.py
@@ -24,9 +24,7 @@ def __init__(
):
if not precision:
precision = get_default_supported_precision(training=False)
self.fabric = Fabric(
devices=devices, accelerator=accelerator, precision=precision
)
self.fabric = Fabric(devices=devices, accelerator=accelerator, precision=precision)

self.sampler = sampler if sampler else TopK(temperature=0.8, k=200)
self.max_length = max_length
@@ -53,16 +51,12 @@ def run(self, **model_inputs) -> Generator[torch.Tensor, None, None]:
def setup(self) -> None:
t = perf_counter()
self.load_model()
self.fabric.print(
f"load model in {perf_counter() - t:.02f} seconds", file=sys.stderr
)
self.fabric.print(f"load model in {perf_counter() - t:.02f} seconds", file=sys.stderr)
if self.compile:
self.compile_model()
t = perf_counter()
self.setup_model()
self.fabric.print(
f"setup model in {perf_counter() - t:.02f} seconds", file=sys.stderr
)
self.fabric.print(f"setup model in {perf_counter() - t:.02f} seconds", file=sys.stderr)

def reset_sampler(self, sampler: Sampler) -> None:
self.sampler = sampler
14 changes: 5 additions & 9 deletions osc_llm/engines/v1.py
@@ -34,25 +34,21 @@ def load_model(self) -> None:
self.model = self.fabric.to_device(self.model)

with self.fabric.init_tensor():
self.model.setup_kv_cache(
batch_size=1, max_length=self.max_length, dtype=torch.bfloat16
)
self.model.setup_kv_cache(batch_size=1, max_length=self.max_length, dtype=torch.bfloat16)

def compile_model(self) -> None:
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.triton.cudagraph_trees = (
False  # currently buggy when used as a server
torch._inductor.config.fx_graph_cache = (
True # Experimental feature to reduce compilation times, will be on by default in future
)
torch._inductor.config.triton.cudagraph_trees = False  # currently buggy when used as a server

torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

self.model: TransformerDecoder = torch.compile(
self.model, dynamic=True, fullgraph=True, mode="reduce-overhead"
)
self.model: TransformerDecoder = torch.compile(self.model, dynamic=True, fullgraph=True, mode="reduce-overhead")

def setup_model(self) -> None:
self.model = self.fabric.setup_module(self.model)
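Note (illustration only, not part of this diff): the flags above are torch.compile/inductor knobs; mode="reduce-overhead" enables CUDA graphs on GPU, which is why the cudagraph_trees flag matters for the decode path. A self-contained toy showing the same call shape (the module is hypothetical):

# Toy module, purely to show the torch.compile call used above.
import torch

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

compiled = torch.compile(Toy(), dynamic=True, fullgraph=True, mode="reduce-overhead")
print(compiled(torch.ones(3)))  # tensor([2., 2., 2.])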
24 changes: 8 additions & 16 deletions osc_llm/engines/v2.py
@@ -32,34 +32,28 @@ def load_model(self) -> None:
), "Only TransformerDecoder Architecture is supported"

with self.fabric.init_module(empty_init=True):
self.prefill_model = build_model(
config=config_path, empty_init=False
).eval()
self.prefill_model = build_model(config=config_path, empty_init=False).eval()
self.decode_model = build_model(config=config_path, empty_init=False).eval()

self.fabric.load_raw(states_path, self.prefill_model)
self.fabric.load_raw(states_path, self.decode_model)

with self.fabric.init_tensor():
self.prefill_model.setup_kv_cache(
batch_size=1, max_length=self.max_length, dtype=torch.bfloat16
)
self.prefill_model.setup_kv_cache(batch_size=1, max_length=self.max_length, dtype=torch.bfloat16)

self.decode_model.kv_caches = self.prefill_model.kv_caches
self.decode_model.mask_cache = self.prefill_model.mask_cache

def compile_model(self) -> None:
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.fx_graph_cache = (
True # Experimental feature to reduce compilation times, will be on by default in future
)
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.suppress_errors = True
self.decode_model: TransformerDecoder = torch.compile(
self.decode_model, mode="reduce-overhead", fullgraph=True
)
self.prefill_model: TransformerDecoder = torch.compile(
self.prefill_model, dynamic=True
)
self.decode_model: TransformerDecoder = torch.compile(self.decode_model, mode="reduce-overhead", fullgraph=True)
self.prefill_model: TransformerDecoder = torch.compile(self.prefill_model, dynamic=True)

def setup_model(self) -> None:
self.prefill_model = self.fabric.setup_module(self.prefill_model)
@@ -79,9 +73,7 @@ def run(
stop_ids = [self.fabric.to_device(stop_id) for stop_id in stop_ids]

# prefill
max_length = (
self.max_length if self.max_length else self.prefill_model.block_size
)
max_length = self.max_length if self.max_length else self.prefill_model.block_size
input_ids = self.prefill(input_ids=input_ids.view(1, -1), input_pos=input_pos)
yield input_ids

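Note (illustration only, not part of this diff): this engine compiles prefill (dynamic shapes) and decode (fixed shapes, CUDA-graph friendly) as separate models, and the assignment decode_model.kv_caches = prefill_model.kv_caches above makes both operate on the same cache tensors, so KV memory is not duplicated. A minimal sketch of that aliasing (the shape is made up):

# Minimal sketch of the cache-sharing idea above; the shape is hypothetical.
import torch

kv_buffer = torch.zeros(1, 8, 2048, 64)   # (batch, heads, max_length, head_size), illustrative only
prefill_cache = kv_buffer                 # prefill writes the prompt positions into this buffer
decode_cache = kv_buffer                  # decode then appends one position per generated token
assert decode_cache.data_ptr() == prefill_cache.data_ptr()   # same storage, no duplicated KV memory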
