Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix an issue during batched inference of the Sortformer diarizer #12047

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_diar_label_lhotse.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
target_fr_len = get_hidden_length_from_sample_length(
audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame
)
target_lens_list.append([target_fr_len])
target_lens_list.append(target_fr_len)
target_lens = torch.tensor(target_lens_list)

return audio, audio_lens, targets, target_lens
15 changes: 10 additions & 5 deletions nemo/collections/asr/models/sortformer_diar_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,24 @@ def frontend_encoder(self, processed_signal, processed_signal_length):
emb_seq = self.sortformer_modules.encoder_proj(emb_seq)
return emb_seq, emb_seq_length

def forward_infer(self, emb_seq, emb_seq_length):
    """
    The main forward pass for diarization for offline diarization inference.

    Args:
        emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors).
            Dimension: (batch_size, diar_frame_count, emb_dim)
        emb_seq_length (torch.Tensor): tensor containing lengths of FastConformer encoder states.
            Dimension: (batch_size,)

    Returns:
        preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels.
            Dimension: (batch_size, diar_frame_count, num_speakers)
    """
    # Build a per-frame validity mask from the true sequence lengths so that
    # padded frames in a batch are excluded from attention and predictions.
    encoder_mask = self.sortformer_modules.length_to_mask(emb_seq_length, emb_seq.shape[1])
    trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask)
    _preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq)
    # Zero out sigmoid outputs in the padded region so batched inference
    # produces the same result as single-sample inference.
    preds = _preds * encoder_mask.unsqueeze(-1)
    return preds

def _diarize_forward(self, batch: Any):
Expand Down Expand Up @@ -407,6 +410,8 @@ def process_signal(self, audio_signal, audio_signal_length):
processed_signal, processed_signal_length = self.preprocessor(
input_signal=audio_signal, length=audio_signal_length
)
if not self.training:
torch.cuda.empty_cache()
return processed_signal, processed_signal_length

def forward(
Expand Down Expand Up @@ -434,10 +439,10 @@ def forward(
if self._cfg.get("streaming_mode", False):
raise NotImplementedError("Streaming mode is not implemented yet.")
else:
emb_seq, _ = self.frontend_encoder(
emb_seq, emb_seq_length = self.frontend_encoder(
processed_signal=processed_signal, processed_signal_length=processed_signal_length
)
preds = self.forward_infer(emb_seq)
preds = self.forward_infer(emb_seq, emb_seq_length)
return preds

def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict:
Expand Down
22 changes: 8 additions & 14 deletions nemo/collections/asr/modules/sortformer_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,22 @@ def __init__(
self.dropout = nn.Dropout(dropout_rate)
self.encoder_proj = nn.Linear(self.fc_d_model, self.tf_d_model)

def length_to_mask(self, lengths, max_length):
    """
    Convert length values to an encoder mask input tensor.

    Args:
        lengths (torch.Tensor): tensor containing lengths (frame counts) of sequences.
            Dimension: (batch_size,)
        max_length (int): maximum length (frame count) of the sequences in the batch

    Returns:
        mask (torch.Tensor): boolean tensor of shape (batch_size, max_length) that is
            True inside each sequence's valid region and False in the padded region.
            NOTE(review): earlier revisions returned a float mask; callers that
            multiply by this mask rely on bool->float type promotion — confirm
            downstream consumers accept a bool dtype.
    """
    batch_size = lengths.shape[0]
    # Broadcast-compare a [0, max_length) ramp against each sequence length:
    # position i of row b is valid iff i < lengths[b].
    arange = torch.arange(max_length, device=lengths.device)
    mask = arange.expand(batch_size, max_length) < lengths.unsqueeze(1)
    return mask

def forward_speaker_sigmoids(self, hidden_out):
"""
Expand Down
Loading