Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Oct 1, 2024
1 parent ca63724 commit d02435e
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python/perf-kernels/flash-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,14 +316,20 @@ def get_gfx_version():
print(f"Error: {e}")
return None


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')


def is_rdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101",
"gfx1102", "gfx1200", "gfx1201")


def get_cdna_autotune_configs():
return [
Expand Down Expand Up @@ -362,6 +368,7 @@ def get_rdna_autotune_configs():
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']


def get_autotune_configs():
if is_rdna():
return get_rdna_autotune_configs()
Expand All @@ -372,6 +379,8 @@ def get_autotune_configs():


autotune_configs, autotune_keys = get_autotune_configs()


@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
Expand Down

0 comments on commit d02435e

Please sign in to comment.