Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed deadlock in sgmv_shrink kernel caused by skewed segments #35

Merged
merged 3 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions csrc/sgmv_flashinfer/sgmv_flashinfer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp,
constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero;
const uint32_t problem_id = blockIdx.y;
const uint32_t bx = blockIdx.x;
const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1];
constexpr uint32_t num_stages = 2;
constexpr uint32_t num_k_frags = 8;
constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity<T>();
Expand All @@ -45,8 +44,9 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp,
uint32_t w_frag[num_k_frags][num_blocks_n][4];
float y_frag[num_blocks_n][8];

for (uint32_t i = 0;
i < (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16); ++i) {
const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1];
const uint32_t num_steps = (s_start < s_end) ? (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16) : 0;
for (uint32_t i = 0; i < num_steps; ++i) {
// init y_frag
if (bx == 0) {
if constexpr (num_blocks_n == 1) {
Expand Down Expand Up @@ -335,6 +335,20 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp,
}
}
}

// handle the case where one of the segments needs more steps than this one
// to avoid deadlock
if constexpr (cooperative) {
uint32_t max_segment_size = 0;
for (uint32_t i = 0; i < num_problems; ++i) {
max_segment_size = max(max_segment_size, s[i + 1] - s[i]);
}

const uint32_t max_steps = (max_segment_size + (num_warps * 16 - 1)) / (num_warps * 16);
for (uint32_t i = 0; i < max_steps - num_steps; ++i) {
grid.sync();
}
}
}

} // namespace sgmv
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sgmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def get_lora_lens(bs: int, popularity: str) -> list[int]:
a *= alpha
lens.append(bs - sum(lens))
return sorted(lens, reverse=True)
if popularity.startswith("skewed"):
if bs < 3:
return [bs]
# Create a highly imbalanced distribution by setting the first segment
# length to 1 and the remainder to the second segment.
return [1, bs - 1]
raise KeyError(popularity)


Expand Down Expand Up @@ -81,7 +87,7 @@ def lora_ref_impl(
pytest.param("expand", marks=pytest.mark.xfail(reason="TODO: sgmv expand")),
],
)
@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical"])
@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical", "skewed"])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 7, 10, 16, 32, 64, 133])
@torch.inference_mode()
def test_sgmv_correctness(dtype_str, h, r, direction, popularity, batch_size):
Expand Down