
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 9, 2025
1 parent 57c8f34 commit b700aff
Showing 3 changed files with 49 additions and 24 deletions.
16 changes: 11 additions & 5 deletions tests/pytorch/test_permutation.py
@@ -667,7 +667,9 @@ def backward_wrapper(
         lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
     )
     t2 = perf_test_cuda_kernel(
-        lambda: te_permute(te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask")
+        lambda: te_permute(
+            te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
+        )
     )
     print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")

@@ -968,21 +970,21 @@ def _test_permutation_mask_map_alongside_probs(
         idx = random.randint(0, num_expert * tp_size - 1)
         split_sizes[idx] += 1
     split_sizes = torch.tensor(split_sizes, dtype=torch.int32)
-    split_sizes_cuda = split_sizes.to(device='cuda')
+    split_sizes_cuda = split_sizes.to(device="cuda")

     _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32)
     sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel()
     sorted_idxs_cuda = sorted_idxs.to(device="cuda")

     split_sizes_2 = [split_sizes[i] for i in sorted_idxs.tolist()]
     split_sizes_2 = torch.tensor(split_sizes_2, dtype=torch.int32)
-    split_sizes_2_cuda = split_sizes_2.to(device='cuda')
+    split_sizes_2_cuda = split_sizes_2.to(device="cuda")

     sorted_idxs_2 = [0] * (num_expert * tp_size)
     for i in range(num_expert * tp_size):
         sorted_idxs_2[sorted_idxs[i]] = i
     sorted_idxs_2 = torch.tensor(sorted_idxs_2, dtype=torch.int32)
-    sorted_idxs_2_cuda = sorted_idxs_2.to(device='cuda')
+    sorted_idxs_2_cuda = sorted_idxs_2.to(device="cuda")

     ###################################################################################################################################
     #
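Aside: the sorted_idxs_2 loop in the unchanged context above builds the inverse permutation of sorted_idxs by hand; torch.argsort applied to a permutation produces the same inverse:

# Equivalent to the loop: sorted_idxs_2[sorted_idxs[i]] = i
sorted_idxs_2 = torch.argsort(sorted_idxs).to(torch.int32)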
@@ -1019,7 +1021,11 @@ def _test_permutation_mask_map_alongside_probs(
     te_probs.requires_grad_(True)

     te_permute_output, row_id_map, te_permuted_probs = te_permute(
-        te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask", probs=te_probs
+        te_permute_fwd_input,
+        routing_map,
+        num_out_tokens=num_out_tokens,
+        map_type="mask",
+        probs=te_probs,
     )

     te_permute_output, te_permuted_probs = te_sort_chunks_by_index(
26 changes: 14 additions & 12 deletions transformer_engine/pytorch/permutation.py
@@ -383,8 +383,8 @@ def forward(
             assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
             if merging_probs.dtype != torch.float32:
                 warnings.warn(
-                    f"The data type of the input `merging_probs` of Unpermute is {merging_probs.dtype}! "
-                    "The recommended type is torch.float32."
+                    "The data type of the input `merging_probs` of Unpermute is"
+                    f" {merging_probs.dtype}! The recommended type is torch.float32."
                 )
                 merging_probs = merging_probs.to(torch.float32)

@@ -457,16 +457,18 @@ def backward(ctx, unpermuted_act_grad):
             fp8_dtype = None

         if ctx.with_probs:
-            act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
-                unpermuted_act_grad,
-                row_id_map,
-                fwd_input,
-                merging_probs,
-                ctx.num_tokens,
-                ctx.num_experts,
-                ctx.num_permuted_tokens,
-                ctx.hidden_size,
-                fp8_dtype,
+            act_grad, probs_grad = (
+                triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
+                    unpermuted_act_grad,
+                    row_id_map,
+                    fwd_input,
+                    merging_probs,
+                    ctx.num_tokens,
+                    ctx.num_experts,
+                    ctx.num_permuted_tokens,
+                    ctx.hidden_size,
+                    fp8_dtype,
+                )
             )
         else:
             act_grad, _ = triton_permutation.permute_with_mask_map(
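For orientation, the wrapped call computes two gradients at once: each permuted row of the activation gradient is the incoming token gradient scaled by that row's merging probability, and the probability gradient for a (token, expert) pair is the dot product of the saved forward-input row with the incoming gradient (0.0 for unrouted pairs, as the kernel's else branch further below stores). A naive reference sketch, assuming a [num_experts, num_tokens] row_id_map with -1 marking unrouted entries (that layout is an assumption for illustration):

import torch

def unpermute_bwd_with_merging_probs_ref(grad_out, row_id_map, fwd_input, merging_probs):
    # grad_out:      [num_tokens, hidden]           grad w.r.t. the unpermuted output
    # row_id_map:    [num_experts, num_tokens]      permuted-row index, -1 if unrouted (assumed)
    # fwd_input:     [num_permuted_tokens, hidden]  input saved from the forward pass
    # merging_probs: [num_tokens, num_experts]
    act_grad = torch.zeros_like(fwd_input)
    probs_grad = torch.zeros_like(merging_probs)  # unrouted (token, expert) slots stay 0.0
    num_experts, num_tokens = row_id_map.shape
    for t in range(num_tokens):
        for e in range(num_experts):
            row = int(row_id_map[e, t])
            if row < 0:
                continue
            act_grad[row] = (merging_probs[t, e] * grad_out[t]).to(act_grad.dtype)
            probs_grad[t, e] = torch.dot(fwd_input[row].float(), grad_out[t].float())
    return act_grad, probs_grad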
31 changes: 24 additions & 7 deletions transformer_engine/pytorch/triton/permutation.py
@@ -266,13 +266,18 @@ def _unpermute_kernel(
             if FP8_DTYPE is not None:
                 inp = inp.to(data_type, bitcast=True).to(compute_type)
             if WITH_MERGING_PROBS:
-                merging_prob_off = pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+                merging_prob_off = (
+                    pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+                )
                 merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
                 inp *= merging_prob
             accumulator += inp
         if PERMUTE_PROBS:
             if current_start == 0:
-                unpermuted_prob_off = pid * stride_unpermuted_probs_token + expert_idx * stride_unpermuted_probs_expert
+                unpermuted_prob_off = (
+                    pid * stride_unpermuted_probs_token
+                    + expert_idx * stride_unpermuted_probs_expert
+                )
                 if src_row != -1:
                     permuted_prob_off = src_row * stride_permuted_probs_token
                     prob = tl.load(permuted_probs_ptr + permuted_prob_off)
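In words, the accumulation this hunk reformats computes, for output token t (the kernel's pid), a sum over experts of merging-prob-weighted permuted rows, plus, when PERMUTE_PROBS is set, a scatter of the permuted probabilities back to [num_tokens, num_experts] layout. A single-token reference sketch, under the same assumed row_id_map layout as in the backward sketch above:

import torch

def unpermute_one_token(t, permuted, row_id_map, merging_probs, permuted_probs):
    num_experts = row_id_map.shape[0]
    out = torch.zeros(permuted.shape[1], dtype=torch.float32, device=permuted.device)
    probs_row = torch.zeros(num_experts, dtype=permuted_probs.dtype, device=permuted.device)
    for e in range(num_experts):
        row = int(row_id_map[e, t])
        if row < 0:
            continue  # token t was not routed to expert e
        out += merging_probs[t, e] * permuted[row].float()  # the WITH_MERGING_PROBS path
        probs_row[e] = permuted_probs[row]                  # the PERMUTE_PROBS path
    return out, probs_row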
@@ -309,7 +314,9 @@ def unpermute_with_mask_map(
         fp8_dtype = None
     output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
     if permuted_probs is not None:
-        unpermuted_probs = torch.empty((num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda")
+        unpermuted_probs = torch.empty(
+            (num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda"
+        )
     else:
         unpermuted_probs = None
     grid = (num_tokens,)
@@ -405,7 +412,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
             inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
             if FP8_DTYPE is not None:
                 inp = inp.to(data_type, bitcast=True).to(compute_type)
-            merging_prob_off = pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+            merging_prob_off = (
+                pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+            )
             merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
             output = inp * merging_prob
             if FP8_DTYPE is not None:
@@ -425,10 +434,16 @@ def _unpermute_bwd_with_merging_probs_kernel(
                 prob_grad_accum += fwd_input.to(tl.float32) * inp.to(tl.float32)
             current_start += BLOCK_SIZE
         probs_grad = tl.sum(prob_grad_accum)
-        probs_grad_off = pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
+        probs_grad_off = (
+            pid * stride_merging_probs_grad_token
+            + expert_idx * stride_merging_probs_grad_expert
+        )
         tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
     else:
-        probs_grad_off = pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
+        probs_grad_off = (
+            pid * stride_merging_probs_grad_token
+            + expert_idx * stride_merging_probs_grad_expert
+        )
         tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0)
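The offset expression wrapped here (and in the other hunks of this file) is plain 2-D strided addressing: it locates element (token, expert) of a strided [num_tokens, num_experts] tensor. As a standalone sketch:

# Flat offset of element (token, expert) in a strided 2-D tensor; for a
# contiguous [num_tokens, num_experts] tensor, stride_token == num_experts
# and stride_expert == 1.
def flat_offset(token: int, expert: int, stride_token: int, stride_expert: int) -> int:
    return token * stride_token + expert * stride_expert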


@@ -453,7 +468,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
     act_grad = torch.empty(
         (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
     )
-    merging_probs_grad = torch.empty((num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda")
+    merging_probs_grad = torch.empty(
+        (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
+    )
     grid = (num_tokens,)
     _unpermute_bwd_with_merging_probs_kernel[grid](
         fwd_output_grad,
