streamk v0.1 #619

Merged
merged 6 commits into from
Jul 31, 2024

xiaohuguo2023
Member

Triton stream-k gemm v0.1

  1. performance comparable to tune_gemm
  2. persistent, non-atomic kernel implementation
  3. pid renumbering based on the chiplet structure of MI300X
  4. dynamic grid setting
  5. tuning script adapted from tune_gemm

@zhanglx13

Can you write a README to introduce the features implemented in this version of the streamK kernel?

@vgokhale
Collaborator

> persistent non atomic kernel implementation

What does this mean?

@xiaohuguo2023
Member Author

In this version, the stream-k kernel uses a persistent loop, so a workgroup (WG) may work on multiple output tiles, and a workgroup may also do only part of the work for an output tile.
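
A minimal CPU-side sketch of that idea, not the kernel from this PR: the flattened (tile, K-iteration) space is split evenly across workgroups, so one workgroup's range can cover several tiles and can start or stop in the middle of a tile. The function name `streamk_partition` and its parameters are hypothetical, and the even split is an assumption about how work is divided.

```python
def streamk_partition(total_tiles, iters_per_tile, num_wgs):
    """Split the flattened (tile, k-iteration) space evenly across workgroups.

    Returns one list per workgroup; each entry is (tile_id, k_begin, k_end),
    a half-open range of K-iterations inside that tile.
    """
    total_iters = total_tiles * iters_per_tile
    base, extra = divmod(total_iters, num_wgs)
    work = []
    start = 0
    for wg in range(num_wgs):
        # The first `extra` workgroups take one additional iteration.
        end = start + base + (1 if wg < extra else 0)
        segments = []
        it = start
        # Persistent loop: keep going until this workgroup's range is used up.
        while it < end:
            tile_id, k_begin = divmod(it, iters_per_tile)
            # Stop at the tile boundary or at the end of the range.
            k_end = min(iters_per_tile, k_begin + (end - it))
            segments.append((tile_id, k_begin, k_end))
            it += k_end - k_begin
        work.append(segments)
        start = end
    return work
```

With 3 tiles, 4 K-iterations per tile, and 2 workgroups, tile 1 is shared: the first workgroup handles its iterations 0 to 2 and the second handles 2 to 4.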

@vgokhale
Collaborator

But it uses atomics, right? Did you mean non-atomic as in it does not do an atomic add?

@xiaohuguo2023
Member Author

Yeah, my description was not precise. We still use atomics for the spin lock, but not atomic_add for the final output.
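
The distinction can be illustrated with a small CPU simulation using threads. This is only a sketch under assumptions, not the kernel's code: each contributing workgroup writes its partial result to a private scratch slot and then sets a flag, and the tile's owner spins on the flags, sums the partials, and writes the output once. The names `reduce_tile_with_flags`, `scratch`, and `flags` are hypothetical; on a GPU the flag accesses would be atomic operations, while the final store is a plain write.

```python
import threading
import time


def reduce_tile_with_flags(partials):
    """Combine per-workgroup partial sums for one tile without an atomic add.

    partials[0] belongs to the tile's owner; the others are contributors.
    """
    n = len(partials)
    scratch = [0.0] * n  # one private slot per contributor
    flags = [0] * n      # 0 = not ready, 1 = partial result published
    result = []

    def contributor(i):
        scratch[i] = partials[i]  # write the partial result first
        flags[i] = 1              # then publish it (an atomic store on a GPU)

    def owner():
        acc = partials[0]
        for i in range(1, n):
            while flags[i] == 0:  # spin until contributor i has published
                time.sleep(0)     # yield so other threads can run
            acc += scratch[i]
        result.append(acc)        # single plain store of the final output

    threads = [threading.Thread(target=owner)]
    threads += [threading.Thread(target=contributor, args=(i,)) for i in range(1, n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result[0]
```

Because only the owner writes the output, the summation order is fixed, which an atomic-add accumulation does not guarantee.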

@xiaohuguo2023
Member Author

> Can you write a README to introduce the features implemented in this version of the streamK kernel?

done

@@ -0,0 +1,43 @@
# streamk gemm script v0.1
Collaborator

What would be needed to get it to 1.0?

Member Author

I need to make it ready to explore half a million benchmarks and reach performance comparable to Tensile development.

Collaborator

I don't think we can have it comparable to tensile because that is outside of the scope of streamk. I think we can call this 0.1 until we have the wider tuning space working.

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    if EVEN_K:
        a = tl.load(A_BASE)
Collaborator

Can we peel the masking for the last iteration when EVEN_K is False so that only the last loop pays the price of the mask?

Member Author

As discussed, this will be in the next PR. Thanks!
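
For reference, the peeling suggested above can be sketched in plain Python on a 1-D dot product. This is an illustration under assumptions, not the planned kernel change: the main loop covers only full K-blocks and needs no bounds check, and a single peeled tail block applies the mask. `blocked_dot` and `block_k` are hypothetical names.

```python
def blocked_dot(a, b, block_k):
    """Dot product over K in blocks, with bounds checks only in the final block."""
    K = len(a)
    full_blocks = K // block_k
    acc = 0.0
    # Main loop: every block is full, so no mask is needed.
    for blk in range(full_blocks):
        base = blk * block_k
        for j in range(block_k):
            acc += a[base + j] * b[base + j]
    # Peeled tail: runs at most once and is the only masked part.
    base = full_blocks * block_k
    if base < K:
        for j in range(block_k):
            in_bounds = base + j < K  # the "mask"
            acc += (a[base + j] * b[base + j]) if in_bounds else 0.0
    return acc
```

When K is a multiple of `block_k`, the tail is skipped entirely, which matches the EVEN_K case.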



@triton.jit()
def get_new_pid(current_pid, num_sms):
Collaborator

s/num_sms/num_cus/
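
The body of `get_new_pid` is not shown in this hunk, so the following plain-Python sketch is only one plausible way to renumber pids by chiplet. It assumes that consecutive pids are dispatched round-robin across the XCDs and that `num_cus` is divisible by `num_xcds`; `remap_pid` is a hypothetical name, and the default of 8 XCDs matches MI300X.

```python
def remap_pid(current_pid, num_cus, num_xcds=8):
    """Renumber a pid so that pids sharing an XCD become contiguous."""
    pids_per_xcd = num_cus // num_xcds
    # XCD this pid lands on, assuming round-robin dispatch across XCDs.
    xcd = current_pid % num_xcds
    # Position of this pid among the pids assigned to that XCD.
    local_pid = current_pid // num_xcds
    return xcd * pids_per_xcd + local_pid
```

With 304 CUs, pids 0, 8, 16, ... (all on XCD 0 under the assumed dispatch) map to 0, 1, 2, ..., so neighbouring tiles are processed on the same chiplet.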

# Number of XCDs
num_xcds = 8
# Number of pids per XCD in the new arrangement
pids_per_xcd = num_sms // num_xcds
Collaborator

I thought the grid could have a multiple of num_cus pids.

Member Author

For a persistent kernel, the grid has to be num_cus, or total_tiles if total_tiles < num_cus.
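
Stated as code, this rule is a `min` over the two quantities. The helper below is a sketch with hypothetical names (`persistent_grid`, `block_m`, `block_n`), assuming one output tile per `block_m` x `block_n` block of the M x N result.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def persistent_grid(M, N, block_m, block_n, num_cus):
    """Number of workgroups to launch for a persistent kernel."""
    total_tiles = cdiv(M, block_m) * cdiv(N, block_n)
    # One workgroup per CU, unless there are fewer tiles than CUs.
    return min(num_cus, total_tiles)
```

For example, a 4096 x 4096 output with 256 x 256 tiles has 256 tiles, so on 304 CUs the grid is 256; an 8192 x 8192 output has 1024 tiles, so the grid is capped at 304.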

@xiaohuguo2023 xiaohuguo2023 merged commit 52a908f into main_perf Jul 31, 2024
4 checks passed