Skip to content

Commit

Permalink
Add workaround in IPEX Triton bench (#3971)
Browse files Browse the repository at this point in the history
* Add legacy triton spirv-target launcher support.

* Work around that the XPU event cannot capture the Triton kernel time.
  • Loading branch information
chengjunlu authored Mar 27, 2024
1 parent e842872 commit 10fbce1
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions intel_extension_for_pytorch/_inductor/xpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

import torch

from datetime import datetime

def synchronize():
import torch
if torch.cuda.is_available():
torch.cuda.synchronize()
elif torch.xpu.is_available():
torch.xpu.synchronize()


def triton_do_bench(
fn,
Expand Down Expand Up @@ -58,6 +67,8 @@ def triton_do_bench(
n_repeat = max(1, int(rep / estimate_ms))
start_event = [torch.xpu.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.xpu.Event(enable_timing=True) for i in range(n_repeat)]
start_times = [datetime for i in range(n_repeat)]
end_times = [datetime for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()
Expand All @@ -71,15 +82,21 @@ def triton_do_bench(
x.grad = None
# we clear the L2 cache before each run
cache.zero_()
synchronize()
# record time of `fn`
start_event[i].record()
start_times[i] = datetime.now()
fn()
end_event[i].record()
synchronize()
end_times[i] = datetime.now()
# Record clocks
torch.xpu.synchronize()
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float
)
# times = torch.tensor(
# [s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float
# )
times = torch.tensor([(e.timestamp() - s.timestamp()) * 1000 for s, e in zip(start_times, end_times)],
dtype=torch.float)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if len(ret) == 1:
Expand Down

0 comments on commit 10fbce1

Please sign in to comment.