Skip to content

Commit

Permalink
fix flamingo init
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-awadalla authored Mar 20, 2024
1 parent 52ca075 commit 292afa1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion open_flamingo/src/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def __init__(
gradient_checkpointing=gradient_checkpointing,
)
self.lang_model.set_decoder_layers_attr_name(decoder_layers_attr_name)
self.decoder_layers_attr_name = decoder_layers_attr_name
self.lang_model.init_cross_attention_layers(
lang_hidden_size=self.lang_hidden_dim,
vis_hidden_size=self.vis_embedding_dim,
Expand Down Expand Up @@ -491,7 +492,7 @@ def lambda_fn(module: nn.Module):
return True
if isinstance(module, GatedCrossAttentionBlock):
return True
if isinstance(module, original_decoder_block_class):
if isinstance(module, decoder_block_class):
return True

return lambda_fn
Expand Down

0 comments on commit 292afa1

Please sign in to comment.