Fix LM kernels

Signed-off-by: Vladimir Bataev <[email protected]>
artbataev committed Feb 13, 2025
1 parent b64a260 commit 4af8aa2

Showing 2 changed files with 14 additions and 16 deletions.
7 changes: 5 additions & 2 deletions nemo/collections/asr/parts/submodules/ngram_lm.py
@@ -37,7 +37,10 @@
 if TRITON_AVAILABLE:
     import triton

-    from nemo.collections.asr.parts.submodules.ngram_lm_triton import _ngram_advance_triton_kernel
+    from nemo.collections.asr.parts.submodules.ngram_lm_triton import (
+        _ngram_advance_triton_kernel,
+        _ngram_advance_triton_kernel_v2,
+    )

Check notice (Code scanning / CodeQL): Unused import. Import of '_ngram_advance_triton_kernel' is not used.


def _log_10_to_e(score):
@@ -864,7 +867,7 @@ def _advance_triton(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         scores = torch.empty([batch_size, self.vocab_size], device=device, dtype=self.arcs_weights.dtype)
         new_states = torch.empty([batch_size, self.vocab_size], dtype=torch.long, device=device)

-        _ngram_advance_triton_kernel[batch_size,](
+        _ngram_advance_triton_kernel_v2[batch_size,](
             vocab_size=self.vocab_size,
             states_ptr=states,
             new_states_ptr=new_states,
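As a side note, the [batch_size,] grid above launches one kernel instance per batch element; tl.program_id(0) then selects which row of the [batch_size, vocab_size] outputs that instance fills. Below is a minimal toy example of this launch pattern (a hypothetical illustration, not code from the repository; assumes Triton and a CUDA device are available).

import torch
import triton
import triton.language as tl


@triton.jit
def _toy_fill_rows_kernel(out_ptr, vocab_size, BLOCK_SIZE: tl.constexpr):
    # One program instance per batch element; program_id(0) is the row index.
    batch_i = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < vocab_size
    row_value = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + batch_i.to(tl.float32)
    tl.store(out_ptr + batch_i * vocab_size + offsets, row_value, mask=mask)


if torch.cuda.is_available():
    batch_size, vocab_size = 4, 1024
    out = torch.empty(batch_size, vocab_size, device="cuda")
    # Same launch shape as in the diff: a 1-D grid with batch_size program instances.
    _toy_fill_rows_kernel[batch_size,](out, vocab_size, BLOCK_SIZE=triton.next_power_of_2(vocab_size))
    # Each row i of `out` is now filled with the value i.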
23 changes: 9 additions & 14 deletions nemo/collections/asr/parts/submodules/ngram_lm_triton.py
@@ -41,7 +41,8 @@ def _ngram_advance_triton_kernel(
     tl.store(new_states_ptr + batch_i * vocab_size + vocab_offsets, -1, mask=vocab_mask)
     tl.store(scores_ptr + batch_i * vocab_size + vocab_offsets, 0.0, mask=vocab_mask)

-    # done = False
+    accumulated_backoff = 0.0

     for i in range(max_order):
         start_idx = tl.load(state_start_arcs_ptr + cur_state)
         end_idx = tl.load(state_end_arcs_ptr + cur_state)
@@ -53,24 +54,17 @@
         cur_to_states = tl.load(to_states_ptr + indices, mask=mask)

         not_final_mask = tl.load(new_states_ptr + batch_i * vocab_size + cur_ilabels, mask=mask, other=0) == -1
-        # not_final_mask &= mask
         tl.store(
             scores_ptr + batch_i * vocab_size + cur_ilabels,
-            tl.load(scores_ptr + batch_i * vocab_size + cur_ilabels, mask=mask) + cur_weights,
+            cur_weights + accumulated_backoff,
             mask=not_final_mask,
         )
         tl.store(new_states_ptr + batch_i * vocab_size + cur_ilabels, cur_to_states, mask=not_final_mask)

-        # done |= (cur_state == start_state)
         # backoff
         cur_backoff_weight = tl.load(backoff_weights_ptr + cur_state)
-        not_final_mask = tl.load(new_states_ptr + batch_i * vocab_size + vocab_offsets, mask=vocab_mask, other=0) == -1
-        tl.store(
-            scores_ptr + batch_i * vocab_size + vocab_offsets,
-            tl.load(scores_ptr + batch_i * vocab_size + vocab_offsets, mask=vocab_mask) + cur_backoff_weight,
-            mask=not_final_mask,
-        )
+        accumulated_backoff += cur_backoff_weight
         cur_state = tl.load(backoff_to_states_ptr + cur_state).to(states_ptr.dtype.element_ty)
         tl.debug_barrier()


@triton.jit
@@ -99,8 +93,8 @@ def _ngram_advance_triton_kernel_v2(
     tl.store(scores_ptr + batch_i * vocab_size + vocab_offsets, 0.0, mask=vocab_mask)

     accumulated_backoff = 0.0
-    done = False
-    while not done:
+    start_state_not_processed = True
+    while start_state_not_processed:
         start_idx = tl.load(state_start_arcs_ptr + cur_state)
         end_idx = tl.load(state_end_arcs_ptr + cur_state)
         indices = start_idx + vocab_offsets
@@ -118,8 +112,9 @@
         )
         tl.store(new_states_ptr + batch_i * vocab_size + cur_ilabels, cur_to_states, mask=not_final_mask)

-        done = cur_state == start_state
+        start_state_not_processed = cur_state != start_state
         # backoff
+        cur_backoff_weight = tl.load(backoff_weights_ptr + cur_state)
+        accumulated_backoff += cur_backoff_weight
         cur_state = tl.load(backoff_to_states_ptr + cur_state).to(states_ptr.dtype.element_ty)
         tl.debug_barrier()
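For reference, here is a minimal pure-Python sketch (an illustration written for this note, not code from the repository) of the per-batch-element computation the fixed kernel performs: walk from the current state down the backoff chain, accumulate backoff weights along the way, and score each token the first time an arc with its label is found. Argument names mirror the kernel's pointer arguments; anything not visible in the diff, such as the ilabels array, is an assumption.

import numpy as np

NO_STATE = -1  # sentinel written into new_states before any arc matches


def ngram_advance_reference(
    state,              # current LM state for one batch element
    start_state,        # state that terminates the backoff walk (processed last)
    vocab_size,
    state_start_arcs,   # index of the first outgoing arc for each state
    state_end_arcs,     # index one past the last outgoing arc for each state
    ilabels,            # input label (token id) of each arc
    to_states,          # destination state of each arc
    arcs_weights,       # log-probability of each arc
    backoff_to_states,  # backoff destination of each state
    backoff_weights,    # backoff log-probability of each state
):
    # A token matched at a higher-order state keeps that score; a backed-off
    # match only fills tokens that are still unscored, and its score is the arc
    # weight plus the backoff weights accumulated on the way down.
    scores = np.zeros(vocab_size, dtype=np.float32)
    new_states = np.full(vocab_size, NO_STATE, dtype=np.int64)

    cur_state = state
    accumulated_backoff = 0.0
    start_state_not_processed = True
    while start_state_not_processed:
        for arc in range(state_start_arcs[cur_state], state_end_arcs[cur_state]):
            token = ilabels[arc]
            if new_states[token] == NO_STATE:  # not matched at a higher order yet
                scores[token] = arcs_weights[arc] + accumulated_backoff
                new_states[token] = to_states[arc]
        start_state_not_processed = cur_state != start_state
        accumulated_backoff += backoff_weights[cur_state]
        cur_state = backoff_to_states[cur_state]
    return scores, new_states

The Triton kernels run this loop once per program instance (one per batch element), writing into row batch_i of scores and new_states; the first kernel bounds the walk by max_order, while the v2 kernel stops once the start state has been processed.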
