Fast N-Gram LM on GPU + greedy decoding (RNN-T, TDT, CTC) #10989

Draft: artbataev wants to merge 56 commits into main from gpu_lm_decoding.

Commits (56)
32aad14 GPU LM Decoding (artbataev, Oct 22, 2024)
0e24533 Merge branch 'main' into gpu_lm_decoding (artbataev, Oct 22, 2024)
520ec2f Merge branch 'main' into gpu_lm_decoding (artbataev, Oct 22, 2024)
5324114 Merge branch 'main' into gpu_lm_decoding (artbataev, Nov 11, 2024)
b303ad4 Merge branch 'main' into gpu_lm_decoding (artbataev, Nov 14, 2024)
d2c1643 Add copyright (artbataev, Nov 14, 2024)
5b67a60 Update baseline (artbataev, Nov 14, 2024)
5791bc4 Merge branch 'main' into gpu_lm_decoding (artbataev, Dec 13, 2024)
24e18a2 Merge branch 'gpu_lm_decoding' of github.com:artbataev/NeMo into gpu_… (artbataev, Dec 13, 2024)
8387a75 Refactor (artbataev, Dec 13, 2024)
8e1245d Separate Triton kernel (artbataev, Dec 13, 2024)
3dd907a Apply isort and black reformatting (artbataev, Dec 13, 2024)
dc10eed Fix guards (artbataev, Dec 13, 2024)
64b1a3b Merge remote-tracking branch 'artbataev/gpu_lm_decoding' into gpu_lm_… (artbataev, Dec 13, 2024)
0428d83 Apply isort and black reformatting (artbataev, Dec 13, 2024)
f857513 Merge branch 'main' into gpu_lm_decoding (artbataev, Dec 16, 2024)
f395bbc Add test. Clean up code (artbataev, Dec 16, 2024)
6da9275 Apply isort and black reformatting (artbataev, Dec 16, 2024)
53e2035 Refactor. Add stubs for differentiable version (artbataev, Dec 16, 2024)
8bbbe92 Apply isort and black reformatting (artbataev, Dec 16, 2024)
cca9fe0 Separate constructor from ARPA (artbataev, Dec 16, 2024)
eb70a69 Apply isort and black reformatting (artbataev, Dec 16, 2024)
84965a4 Fix implementation. Add docstrings (artbataev, Dec 16, 2024)
439c71a Apply isort and black reformatting (artbataev, Dec 16, 2024)
ae956b6 Clean up code (artbataev, Dec 16, 2024)
acf4452 Merge remote-tracking branch 'artbataev/gpu_lm_decoding' into gpu_lm_… (artbataev, Dec 16, 2024)
79d7410 Merge branch 'main' into gpu_lm_decoding (artbataev, Dec 17, 2024)
5f45b76 Clean up optional_libs (artbataev, Dec 17, 2024)
4e2264a Merge branch 'main' into gpu_lm_decoding (artbataev, Dec 18, 2024)
c82f7f9 Refactor API. Add tests. (artbataev, Dec 18, 2024)
d71d84e Apply isort and black reformatting (artbataev, Dec 18, 2024)
21c6331 Fix test (artbataev, Dec 18, 2024)
dc8c97d Merge remote-tracking branch 'artbataev/gpu_lm_decoding' into gpu_lm_… (artbataev, Dec 18, 2024)
5b71766 Fix test (artbataev, Dec 18, 2024)
395bbde Fix dimension (artbataev, Dec 18, 2024)
7bfe88c Apply isort and black reformatting (artbataev, Dec 18, 2024)
53b1e14 Improve memory usage when loading from ARPA + use int32 when possible (artbataev, Jan 7, 2025)
54ace71 Merge branch 'main' into gpu_lm_decoding (artbataev, Jan 7, 2025)
9488ca6 Greatly improve memory usage when loading from ARPA (artbataev, Jan 9, 2025)
95ed27d Fix ngram structure (artbataev, Jan 10, 2025)
cb9e324 Speedup and save memory when loading LM from Arpa (artbataev, Jan 13, 2025)
4128c05 Update GPU-LM (final weight, normalize unk) (artbataev, Jan 29, 2025)
1dd8183 Apply isort and black reformatting (artbataev, Jan 29, 2025)
d0dec2a Load GPU-LM from NeMo file (artbataev, Feb 13, 2025)
dd03049 Save NeMo file with KenLM (artbataev, Feb 13, 2025)
3ed3ca5 Apply isort and black reformatting (artbataev, Feb 13, 2025)
33ea8cb Fix CTC model vocab size query (artbataev, Feb 13, 2025)
ca3809c Merge remote-tracking branch 'artbataev/gpu_lm_decoding' into gpu_lm_… (artbataev, Feb 13, 2025)
472a247 Fix model loading (artbataev, Feb 13, 2025)
57c13dc Fix triton usage (artbataev, Feb 13, 2025)
804c3b7 Fix final weights (artbataev, Feb 13, 2025)
b64a260 Fix LM serialization for hybrid models (artbataev, Feb 13, 2025)
4af8aa2 Fix LM kernels (artbataev, Feb 13, 2025)
fb57728 Move synchronization to top (artbataev, Feb 13, 2025)
6479e28 Add lm build options (artbataev, Feb 14, 2025)
1b304ed Fix LM building for jsonl (artbataev, Feb 15, 2025)
2 changes: 2 additions & 0 deletions nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -274,6 +274,8 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[
compute_timestamps=self.compute_timestamps,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
ngram_lm_model=self.cfg.greedy.get("ngram_lm_model", None),
ngram_lm_alpha=self.cfg.greedy.get("ngram_lm_alpha", 0.0),
)

elif self.cfg.strategy == 'beam':
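Editor's sketch (not part of the diff): with these two fields plumbed through the greedy sub-config, enabling GPU n-gram LM shallow fusion for greedy CTC decoding might look like the snippet below. It assumes `asr_model` is a NeMo CTC model, that "greedy_batch" is the chosen strategy, and the LM path and alpha value are placeholders.

from omegaconf import open_dict

# Sketch only: turn on GPU n-gram LM shallow fusion for batched greedy CTC decoding.
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batch"
    decoding_cfg.greedy.ngram_lm_model = "/path/to/lm.nemo"  # placeholder path
    decoding_cfg.greedy.ngram_lm_alpha = 0.3                 # assumed LM weight
asr_model.change_decoding_strategy(decoding_cfg)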
125 changes: 124 additions & 1 deletion nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -13,11 +13,13 @@
# limitations under the License.

from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.parts.submodules.ngram_lm import FastNGramLM
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.core.classes import Typing, typecheck
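A minimal sketch of the FastNGramLM batched API as it is used further down in this file; the file path and vocabulary size are placeholders, and the shapes are inferred from the call sites below rather than stated by the PR.

# Sketch only: batched n-gram LM queries on the GPU.
lm = FastNGramLM.from_file(lm_path="/path/to/lm.nemo", vocab_size=1024)  # placeholder args
lm_states = lm.get_init_states(batch_size=4, bos=True)   # one LM state per batch element
# scores: [B, V] log-scores for each possible next token;
# lm_state_candidates: [B, V] LM states reached by taking that token.
scores, lm_state_candidates = lm.advance(states=lm_states)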
@@ -337,6 +339,8 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):

"""

ngram_lm_batch: Optional[FastNGramLM]

@property
def input_types(self):
"""Returns definitions of module input ports."""
@@ -360,6 +364,8 @@ def __init__(
compute_timestamps: bool = False,
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
ngram_lm_model: Optional[str | Path] = None,
ngram_lm_alpha: float = 0.0,
):
super().__init__()

@@ -372,6 +378,14 @@
# set confidence calculation method
self._init_confidence_method(confidence_method_cfg)

# init ngram lm
if ngram_lm_model is not None:
self.ngram_lm_batch = FastNGramLM.from_file(lm_path=ngram_lm_model, vocab_size=self.blank_id)
else:
self.ngram_lm_batch = None
self.ngram_lm_alpha = ngram_lm_alpha
self._repeated_symbols_allowed = True

@typecheck()
def forward(
self,
@@ -407,9 +421,15 @@ def forward(
decoder_lengths = decoder_lengths.to(decoder_output.device)

if decoder_output.ndim == 2:
if self.ngram_lm_batch is not None:
raise NotImplementedError
hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
else:
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
if self.ngram_lm_batch is None:
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
else:
self.ngram_lm_batch.to(decoder_output.device)
hypotheses = self._greedy_decode_logprobs_batched_lm(decoder_output, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, input_decoder_lengths)
return (packed_result,)

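A hedged usage sketch for the class above (not part of the diff): constructing the batched greedy decoder with an n-gram LM and running it on random log-probs. The vocabulary size, LM path, alpha, and the CUDA requirement are assumptions; `blank_id` equals the vocabulary size because blank is the last CTC index.

import torch

# Sketch only; assumes a CUDA device (the batched LM kernels target the GPU)
# and an existing ARPA/.nemo n-gram LM at the given path.
device = torch.device("cuda")
greedy_lm_infer = GreedyBatchedCTCInfer(
    blank_id=1024,                      # vocab_size == blank_id for CTC
    ngram_lm_model="/path/to/lm.nemo",  # loaded via FastNGramLM.from_file
    ngram_lm_alpha=0.3,                 # shallow-fusion LM weight
)
log_probs = torch.randn(2, 50, 1025, device=device).log_softmax(dim=-1)  # [B, T, V+1]
lengths = torch.tensor([50, 37])  # per-utterance valid frame counts
(hyps,) = greedy_lm_infer(decoder_output=log_probs, decoder_lengths=lengths)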
@@ -515,6 +535,106 @@ def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor):

return hypotheses

@torch.no_grad()
def _greedy_decode_logprobs_batched_lm(self, x: torch.Tensor, out_len: torch.Tensor):
# x: [B, T, D]
# out_len: [B]

batch_size = x.shape[0]
max_time = x.shape[1]

device = x.device
log_probs = x
float_dtype = log_probs.dtype

batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True)
batch_indices = torch.arange(batch_size, device=device, dtype=torch.long)
predictions_labels = torch.zeros([batch_size, max_time], device=device, dtype=torch.long)
last_labels = torch.full([batch_size], fill_value=self.blank_id, device=device, dtype=torch.long)
predictions_logprobs = torch.zeros([batch_size, max_time], device=device, dtype=float_dtype)
for i in range(max_time):
lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states)
lm_scores = lm_scores.to(dtype=float_dtype)

labels = torch.argmax(log_probs[:, i], dim=-1)
# lm_scores[batch_indices[last_labels != self.blank_id], last_labels[last_labels != self.blank_id]] = 0.0
labels_w_lm = (log_probs[:, i, :-1] + self.ngram_lm_alpha * lm_scores).argmax(dim=-1)
if self._repeated_symbols_allowed:
# is_blank = (labels == self.blank_id)
# torch.where(is_blank, labels, labels_w_lm, out=labels)
blank_or_repeated = (labels == self.blank_id) | (labels == last_labels) | (labels_w_lm == last_labels)
torch.where(blank_or_repeated, labels, labels_w_lm, out=labels)
blank_or_repeated = (labels == self.blank_id) | (labels == last_labels)
torch.where(
blank_or_repeated,
batch_lm_states,
batch_lm_states_candidates[batch_indices, labels * ~blank_or_repeated],
out=batch_lm_states,
)
else:
blank_mask = labels == self.blank_id
torch.where(blank_mask, labels, labels_w_lm, out=labels)
torch.where(
blank_mask,
batch_lm_states,
batch_lm_states_candidates[batch_indices, labels * ~blank_mask],
out=batch_lm_states,
)
predictions_labels[:, i] = labels
# TODO: logprobs
last_labels = labels

        # In CTC greedy decoding, each output maximum-likelihood token
        # is calculated independently of the other tokens.
# predictions_logprobs, predictions_labels = predictions.max(dim=-1)

# Since predictions_logprobs is a padded matrix in the time
# dimension, we consider invalid timesteps to be "blank".
time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1))
# Sum the non-blank labels to compute the score of the
# transcription. This follows from Eq. (3) of "Connectionist
# Temporal Classification: Labelling Unsegmented Sequence Data
# with Recurrent Neural Networks".
scores = torch.where(non_blank_ids_mask, predictions_logprobs, 0.0).sum(axis=1)

scores = scores.cpu()
predictions_labels = predictions_labels.cpu()
out_len = out_len.cpu()

predictions = log_probs
if self.preserve_alignments or self.preserve_frame_confidence:
predictions = predictions.cpu()

hypotheses = []

# This mimics the for loop in GreedyCTCInfer::forward.
for i in range(batch_size):
hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
hypothesis.score = scores[i]

prediction_labels_no_padding = predictions_labels[i, : out_len[i]].tolist()

assert predictions_labels.dtype == torch.int64
hypothesis.y_sequence = prediction_labels_no_padding

if self.preserve_alignments:
hypothesis.alignments = (
predictions[i, : out_len[i], :].clone(),
predictions_labels[i, : out_len[i]].clone(),
)
if self.compute_timestamps:
                # TODO: Could do this in a vectorized manner... Would
# prefer to have nonzero_static, though, for sanity.
# Or do a prefix sum on out_len
hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
if self.preserve_frame_confidence:
hypothesis.frame_confidence = self._get_confidence(predictions[i, : out_len[i], :])

hypotheses.append(hypothesis)

return hypotheses

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

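Editor's illustration of the per-frame shallow-fusion rule implemented in _greedy_decode_logprobs_batched_lm above, on toy numbers (all values invented): the LM score is added only over real tokens, and blank frames keep the purely acoustic decision.

import torch

blank_id = 3
alpha = 0.5
# One frame, batch of 1; blank is the last index.
log_probs_t = torch.tensor([[-1.0, -1.1, -2.5, -3.0]])  # [B, V+1]
lm_scores = torch.tensor([[-3.0, -0.2, -1.0]])          # [B, V], no blank entry

labels = log_probs_t.argmax(dim=-1)                                      # acoustic-only: token 0
labels_w_lm = (log_probs_t[:, :-1] + alpha * lm_scores).argmax(dim=-1)   # fused: token 1
is_blank = labels == blank_id
fused = torch.where(is_blank, labels, labels_w_lm)  # blank frames bypass the LM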
@@ -526,6 +646,9 @@ class GreedyCTCInferConfig:
preserve_frame_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())

ngram_lm_model: Optional[str] = None
ngram_lm_alpha: float = 0.0

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
self.confidence_method_cfg = OmegaConf.structured(
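Finally, a sketch of the two new config fields in isolation (path and weight are placeholders, not values from the PR):

from omegaconf import OmegaConf
from nemo.collections.asr.parts.submodules.ctc_greedy_decoding import GreedyCTCInferConfig

greedy_cfg = OmegaConf.structured(
    GreedyCTCInferConfig(
        ngram_lm_model="/path/to/lm.nemo",  # None (the default) disables LM fusion
        ngram_lm_alpha=0.3,
    )
)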