-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path0001-Add-events-timing-to-MLCChat.patch
145 lines (136 loc) · 6.48 KB
/
0001-Add-events-timing-to-MLCChat.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
From c9950224e21153b59d0e610ab06bd5bfedf98a26 Mon Sep 17 00:00:00 2001
From: Stefanos Laskaridis <[email protected]>
Date: Sun, 3 Mar 2024 17:18:06 +0000
Subject: [PATCH] Add events timing to MLCChat++
---
python/mlc_chat/chat_module.py | 24 +++++++++++++++++++++++-
python/mlc_chat/cli/chat.py | 7 +++++++
python/mlc_chat/interface/chat.py | 6 +++++-
3 files changed, 35 insertions(+), 2 deletions(-)
diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py
index 62ca0135..79de7756 100644
--- a/python/mlc_chat/chat_module.py
+++ b/python/mlc_chat/chat_module.py
@@ -7,6 +7,7 @@ import json
import os
import subprocess
import sys
+import time
import warnings
from dataclasses import asdict, dataclass, fields
from enum import Enum
@@ -719,6 +720,9 @@ class ChatModule: # pylint: disable=too-many-instance-attributes
device_type = self.device.device_type
device_id = self.device.device_id
+ self.energy_events = {}
+ self.generate_counter = 0
+
# 1. Populate chat module and their functions
fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create")
assert fcreate_chat_mod is not None
@@ -844,23 +848,35 @@ class ChatModule: # pylint: disable=too-many-instance-attributes
num_return_sequences = generation_config.n
return_str = False
- for _ in range(num_return_sequences):
+ for idx in range(num_return_sequences):
if stateless:
self.reset_chat()
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.prefill.start"] = time.time_ns()
self._prefill(prompt, generation_config=generation_config)
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.prefill.end"] = time.time_ns()
if not progress_callback:
+ decode_counter = 0
while not self._stopped():
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.start"] = time.time_ns()
self._decode(generation_config=generation_config)
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.end"] = time.time_ns()
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.start"] = time.time_ns()
new_msg = self._get_message()
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.end"] = time.time_ns()
new_msgs.append(new_msg)
else:
# apply callback with a rate of callback_interval
i, new_msg = 0, ""
+ decode_counter = 0
while not self._stopped():
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.start"] = time.time_ns()
self._decode(generation_config=generation_config)
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.decode.{decode_counter}.end"] = time.time_ns()
if i % progress_callback.callback_interval == 0 or self._stopped():
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.start"] = time.time_ns()
new_msg = self._get_message()
+ self.energy_events[f"chat.{self.generate_counter}.{idx}.get_message.end"] = time.time_ns()
progress_callback(new_msg)
i += 1
progress_callback(stopped=True)
@@ -999,11 +1015,15 @@ class ChatModule: # pylint: disable=too-many-instance-attributes
app_config_json: str
The partial config that is used to partially override the model configuration.
"""
+ self.energy_events[f"load_model.start"] = time.time_ns()
self._reload_func(lib, model_path, app_config_json)
+ self.energy_events[f"load_model.end"] = time.time_ns()
def _unload(self):
r"""Unload the chat module and clear memory of all loaded models."""
+ self.energy_events[f"unload_model.start"] = time.time_ns()
self._unload_func()
+ self.energy_events[f"unload_model.end"] = time.time_ns()
def _prefill(
self,
@@ -1209,4 +1229,6 @@ class ChatModule: # pylint: disable=too-many-instance-attributes
def _process_system_prompts(self):
r"""Pre-process by prefilling the system prompts, running prior to any user input."""
+ self.energy_events["prompt.system.start"] = time.time_ns()
self._process_system_prompts_func()
+ self.energy_events["prompt.system.end"] = time.time_ns()
diff --git a/python/mlc_chat/cli/chat.py b/python/mlc_chat/cli/chat.py
index 7ec6efb2..96edef2d 100644
--- a/python/mlc_chat/cli/chat.py
+++ b/python/mlc_chat/cli/chat.py
@@ -37,6 +37,12 @@ def main(argv):
default=None,
help=HELP["model_lib_path"] + ' (default: "%(default)s")',
)
+ parser.add_argument(
+ "--energy-events",
+ type=str,
+ default="energy_events.txt",
+ help="Energy events file to use for energy profiling (default: energy_events.txt)"
+ )
parsed = parser.parse_args(argv)
chat(
model=parsed.model,
@@ -44,4 +50,5 @@ def main(argv):
opt=parsed.opt,
overrides=parsed.overrides,
model_lib_path=parsed.model_lib_path,
+ energy_events_filename=parsed.energy_events,
)
diff --git a/python/mlc_chat/interface/chat.py b/python/mlc_chat/interface/chat.py
index cd473f79..3d23df40 100644
--- a/python/mlc_chat/interface/chat.py
+++ b/python/mlc_chat/interface/chat.py
@@ -122,6 +122,7 @@ def chat(
opt: str,
overrides: ChatConfigOverride,
model_lib_path: Optional[str],
+ energy_events_filename: str,
):
"""chat with a model."""
# Set up chat config and generate config
@@ -146,9 +147,12 @@ def chat(
if prompt[:6] == "/reset":
cm.reset_chat()
elif prompt[:5] == "/exit":
+ with open(energy_events_filename, 'w', encoding='utf-8') as f:
+ for event_key, event_value in cm.energy_events.items():
+ f.write(f"{event_key} {event_value}\n")
break
elif prompt[:6] == "/stats":
- print(cm.stats(), flush=True)
+ print(cm.stats(verbose=True), flush=True)
elif prompt[:4] == "/set":
gen_config_overrides = GenerationConfigOverride.from_str(prompt.split()[1])
generate_config = gen_config_overrides.apply(generate_config)
--
2.43.0