
Commit 2723460
fix
IlyasMoutawwakil committed Dec 9, 2024
1 parent bb59538 commit 2723460
Showing 2 changed files with 55 additions and 28 deletions.
52 changes: 37 additions & 15 deletions optimum_benchmark/trackers/latency.py
@@ -234,19 +234,19 @@ def session(self):
         self.start_time = None
 
     def count(self) -> int:
-        assert self.start_time is not None
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
         assert len(self.start_events) == len(self.end_events)
 
         return len(self.start_events)
 
     def elapsed(self):
-        assert self.start_time is not None
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
 
         return time.perf_counter() - self.start_time
 
     @contextmanager
     def track(self):
-        assert self.start_time is not None
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
 
         if self.is_pytorch_cuda:
             start_event = torch.cuda.Event(enable_timing=True)
@@ -321,7 +321,7 @@ def session(self):
         self.start_time = None
 
     def count(self) -> int:
-        assert self.start_time is not None
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
         assert (
             len(self.prefill_start_events)
             == len(self.prefill_end_events)
@@ -332,12 +332,14 @@ def count(self) -> int:
         return len(self.prefill_start_events)
 
     def elapsed(self):
-        assert self.start_time is not None
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
 
         return time.perf_counter() - self.start_time
 
     @contextmanager
     def track(self):
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
+
         if self.is_pytorch_cuda:
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
@@ -357,14 +359,16 @@ def track(self):
             self.per_token_end_events.extend(self.per_token_events[1:])
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
+
         if self.is_pytorch_cuda:
             event = torch.cuda.Event(enable_timing=True)
             event.record()
         else:
             event = time.perf_counter()
 
         if len(self.prefill_start_events) == len(self.prefill_end_events):
-            # on the first call, there will be the same number of prefill/decode start/end events
+            # on the first call (prefill), there will be the same number of prefill/decode start/end events
             self.prefill_end_events.append(event)
             self.decode_start_events.append(event)
 
@@ -373,6 +377,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
         return scores
 
     def get_prefill_latency(self) -> Latency:
+        assert len(self.prefill_start_events) == len(self.prefill_end_events) > 0
+
         if self.is_pytorch_cuda:
             torch.cuda.synchronize()
 
@@ -391,6 +397,8 @@ def get_prefill_latency(self) -> Latency:
         return Latency.from_values(latencies, unit=LATENCY_UNIT)
 
     def get_decode_latency(self) -> Latency:
+        assert len(self.decode_start_events) == len(self.decode_end_events) > 0
+
         if self.is_pytorch_cuda:
             torch.cuda.synchronize()
 
@@ -409,6 +417,8 @@ def get_decode_latency(self) -> Latency:
         return Latency.from_values(latencies, unit=LATENCY_UNIT)
 
     def get_per_token_latency(self) -> Latency:
+        assert len(self.per_token_start_events) == len(self.per_token_end_events) > 0
+
         if self.is_pytorch_cuda:
             torch.cuda.synchronize()
 
@@ -464,18 +474,20 @@ def session(self):
         self.start_time = None
 
     def count(self) -> int:
-        assert self.start_time is not None
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
         assert len(self.call_start_events) == len(self.call_end_events)
 
        return len(self.call_start_events)
 
     def elapsed(self):
-        assert self.start_time is not None
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
 
         return time.perf_counter() - self.start_time
 
     @contextmanager
     def track(self):
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
+
         if self.is_pytorch_cuda:
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
@@ -495,6 +507,8 @@ def track(self):
             self.per_step_end_events.extend(self.per_step_events[1:])
 
     def __call__(self, pipeline, step_index, timestep, callback_kwargs):
+        assert self.start_time is not None, "This method can only be called inside of a '.session()' context"
+
         if self.is_pytorch_cuda:
             event = torch.cuda.Event(enable_timing=True)
             event.record()
@@ -506,6 +520,8 @@ def __call__(self, pipeline, step_index, timestep, callback_kwargs):
         return callback_kwargs
 
     def get_step_latency(self) -> Latency:
+        assert len(self.per_step_start_events) == len(self.per_step_end_events) > 0
+
         if self.is_pytorch_cuda:
             torch.cuda.synchronize()
 
@@ -524,6 +540,8 @@ def get_step_latency(self) -> Latency:
         return Latency.from_values(latencies, unit=LATENCY_UNIT)
 
     def get_call_latency(self) -> Latency:
+        assert len(self.call_start_events) == len(self.call_end_events) > 0
+
         if self.is_pytorch_cuda:
             torch.cuda.synchronize()
 
@@ -559,20 +577,24 @@ def __init__(self, device: str, backend: str) -> None:
 
     def on_step_begin(self, *args, **kwargs):
         if self.is_pytorch_cuda:
-            self.start_events.append(torch.cuda.Event(enable_timing=True))
-            self.end_events.append(torch.cuda.Event(enable_timing=True))
-            self.start_events[-1].record()
+            event = torch.cuda.Event(enable_timing=True)
+            event.record()
         else:
-            self.start_events.append(time.perf_counter())
+            event = time.perf_counter()
+
+        self.start_events.append(event)
 
     def on_step_end(self, *args, **kwargs):
         if self.is_pytorch_cuda:
-            self.end_events[-1].record()
+            event = torch.cuda.Event(enable_timing=True)
+            event.record()
         else:
-            self.end_events.append(time.perf_counter())
+            event = time.perf_counter()
+
+        self.end_events.append(event)
 
     def get_latency(self) -> Latency:
-        assert len(self.start_events) == len(self.end_events) >= 0
+        assert len(self.start_events) == len(self.end_events) > 0
 
         if self.is_pytorch_cuda:
             torch.cuda.synchronize()
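Note: every get_*_latency method in this file follows the same CUDA-event discipline: events are recorded asynchronously on the stream, so the getters call torch.cuda.synchronize() before reading elapsed times. A minimal sketch of that pattern, assuming only plain torch on a CUDA machine (the matmul is a stand-in workload, not part of the library):

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()  # enqueued on the current stream, not yet resolved
    x = torch.rand(1024, 1024, device="cuda")
    y = x @ x  # stand-in GPU workload being timed
    end.record()

    torch.cuda.synchronize()  # wait for the queued kernels and both records
    milliseconds = start.elapsed_time(end)  # only valid after synchronizing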
31 changes: 18 additions & 13 deletions tests/test_api.py
@@ -23,7 +23,7 @@
 from optimum_benchmark.generators.input_generator import InputGenerator
 from optimum_benchmark.import_utils import get_git_revision_hash
 from optimum_benchmark.system_utils import is_nvidia_system, is_rocm_system
-from optimum_benchmark.trackers import LatencyTracker, MemoryTracker
+from optimum_benchmark.trackers import LatencySessionTracker, MemoryTracker
 
 PUSH_REPO_ID = os.environ.get("PUSH_REPO_ID", "optimum-benchmark/local")
 
@@ -55,6 +55,9 @@
 @pytest.mark.parametrize("scenario", ["training", "inference"])
 @pytest.mark.parametrize("library,task,model", LIBRARIES_TASKS_MODELS)
 def test_api_launch(device, scenario, library, task, model):
+    if scenario == "training" and library != "transformers":
+        pytest.skip("Training is only supported with transformers library models")
+
     benchmark_name = f"{device}_{scenario}_{library}_{task}_{model}"
 
     if device == "cuda":
@@ -65,24 +68,26 @@ def test_api_launch(device, scenario, library, task, model):
         elif is_nvidia_system():
             device_isolation_action = "error"
             device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
+        else:
+            raise RuntimeError("Using CUDA device on a machine that is neither NVIDIA nor ROCM.")
     else:
         device_isolation_action = None
         device_isolation = False
         device_ids = None
 
-    launcher_config = ProcessConfig(device_isolation=device_isolation, device_isolation_action=device_isolation_action)
+    launcher_config = ProcessConfig(
+        device_isolation=device_isolation,
+        device_isolation_action=device_isolation_action,
+    )
 
     if scenario == "training":
-        if library == "transformers":
-            scenario_config = TrainingConfig(
-                memory=True,
-                latency=True,
-                energy=not is_rocm_system(),
-                warmup_steps=2,
-                max_steps=5,
-            )
-        else:
-            pytest.skip("Training scenario is only available for Transformers library")
+        scenario_config = TrainingConfig(
+            memory=True,
+            latency=True,
+            energy=not is_rocm_system(),
+            warmup_steps=2,
+            max_steps=5,
+        )
 
     elif scenario == "inference":
         scenario_config = InferenceConfig(
@@ -227,7 +232,7 @@ def test_api_dataset_generator(library, task, model):
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
 @pytest.mark.parametrize("backend", ["pytorch", "other"])
 def test_api_latency_tracker(device, backend):
-    tracker = LatencyTracker(device=device, backend=backend)
+    tracker = LatencySessionTracker(device=device, backend=backend)
 
     with tracker.session():
         while tracker.elapsed() < 2:
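Note: the rename to LatencySessionTracker makes the session requirement explicit: count(), elapsed(), track(), and the callbacks now assert that they run inside a '.session()' context. A minimal usage sketch, assuming only the constructor signature and methods visible in this diff (the rest of the test is truncated above):

    import time

    from optimum_benchmark.trackers import LatencySessionTracker

    tracker = LatencySessionTracker(device="cpu", backend="other")

    with tracker.session():           # sets start_time, arming the asserts
        while tracker.elapsed() < 2:  # legal only inside the session
            with tracker.track():     # records one start/end event pair
                time.sleep(0.1)

    # tracker.elapsed()  # would raise AssertionError here, outside the session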
