Add padding in mllama vision encoder to align with HF (#11808)
* Add padding in vision_encoder as in HF

* cleanup
meatybobby authored Feb 5, 2025
1 parent 922446e commit c95981a
Showing 1 changed file with 11 additions and 2 deletions.
nemo/collections/vlm/mllama/model/vision.py (11 additions, 2 deletions)
@@ -692,14 +692,21 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
         x = self.apply_positional_embedding(x, ar_ids)
 
         x = self.ln_pre(x)
+
+        # Compute the number of tokens to pad (to be consistent with HF)
+        npad = (8 - (x.shape[-2] % 8)) % 8
+        # Compute padding tuple for pad function
+        padding = (0, 0, 0, npad)  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+        # Pad the tensor
+        x = F.pad(x, padding, mode="constant", value=0)
+
         x = x.view(bsz * num_concurrent_media, -1, dim)
 
-        npad, attn_mask = 0, None
         attn_bias = build_encoder_attention_mask(x, ar_ids, ntok, num_chunks, self.config.supported_aspect_ratios)
         x = x.transpose(0, 1).contiguous()
         x, int_x = self.transformer(
             hidden_states=x,
-            attention_mask=attn_mask,
+            attention_mask=None,
             attention_bias=attn_bias,
             return_intermediate=self.return_intermediate,
         )
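For context, the padding step added in the hunk above can be exercised on its own. Below is a minimal, self-contained sketch (the helper name and the toy shapes are illustrative, not part of this patch) of the pad-to-a-multiple-of-8 logic:

import torch
import torch.nn.functional as F

def pad_tokens_to_multiple_of_8(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Zero-pad dim -2 (the token dimension) up to the next multiple of 8."""
    npad = (8 - (x.shape[-2] % 8)) % 8
    # F.pad takes (left, right) pairs starting from the last dim, so (0, 0, 0, npad)
    # leaves the hidden dim alone and pads only the right side of dim -2.
    x = F.pad(x, (0, 0, 0, npad), mode="constant", value=0)
    return x, npad

# Toy shapes: 1030 tokens are padded to 1032 (npad == 2); a multiple of 8 gets npad == 0.
x = torch.zeros(2, 1030, 1280)
x, npad = pad_tokens_to_multiple_of_8(x)
assert npad == 2 and x.shape == (2, 1032, 1280)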
@@ -719,10 +726,12 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor:
         )
         x = x.transpose(0, 1)
         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
+        x = x[:, :, :ntok]
 
         # adding back intermediate layer outputs
         x = x.reshape(bsz, num_concurrent_media, num_chunks, ntok, dim)
         int_x = int_x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, -1)
+        int_x = int_x[:, :, :ntok]
         # int_x = contract_num_tokens_from_mult8(int_x, npad)
         int_x = int_x.reshape(bsz, num_concurrent_media, num_chunks, ntok, -1)
         x = torch.cat([x, int_x], dim=-1)
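The second hunk undoes the padding after the transformer by slicing back to the original token count before the final reshapes. A small standalone sketch of that trim (again with an illustrative helper name and toy shapes, not taken from the patch):

import torch

def trim_padded_tokens(x: torch.Tensor, ntok: int) -> torch.Tensor:
    """Drop padded positions so later reshapes see exactly ntok tokens per chunk."""
    # x: (batch * num_concurrent_media, num_chunks, ntok + npad, dim)
    return x[:, :, :ntok]

x = torch.zeros(2, 4, 1032, 1280)  # 1030 real tokens + 2 pad per chunk
assert trim_padded_tokens(x, 1030).shape == (2, 4, 1030, 1280)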
