[SW_212175] FLAN-T5 has bad performance when using regional compilation with module.compile (#77)
chaojun-zhang authored and astachowiczhabana committed Jan 7, 2025
1 parent f48dda8 commit 1f70f82
Showing 2 changed files with 15 additions and 1 deletion.
14 changes: 14 additions & 0 deletions optimum/habana/accelerate/accelerator.py
@@ -577,6 +577,20 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model

<<<<<<< HEAD
=======
def compile_regions(self, model):
if isinstance(model, torch.nn.ModuleList):
for name, module in model.named_children():
if self.dynamic is not None:
module.compile(dynamic=self.dynamic, **self.state.dynamo_plugin.to_kwargs())
else:
module.compile(**self.state.dynamo_plugin.to_kwargs())
else:
for _, module in model.named_children():
self.compile_regions(module)

>>>>>>> 6766d6b5 ([SW_212175] FLAN-T5 has bad performance when using regional compilation with module.compile (#77))
def _prepare_deepspeed(self, *args):
import deepspeed

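For context, the `compile_regions` helper added above implements regional compilation: instead of wrapping the whole model in a single `torch.compile` graph, it recurses through the module tree and compiles each child of every `torch.nn.ModuleList` in place via `nn.Module.compile`. Below is a minimal, self-contained sketch of the same pattern, assuming a PyTorch build where `nn.Module.compile` is available (2.x); the `TinyEncoder` model, its sizes, and the bare `module.compile()` call (the diff also forwards the dynamo-plugin kwargs) are illustrative, not code from the repository.

    import torch
    import torch.nn as nn

    class TinyEncoder(nn.Module):
        # Toy stand-in for a T5-style stack: repeated blocks held in a ModuleList.
        def __init__(self, hidden=64, num_layers=4):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    def compile_regions(model: nn.Module) -> None:
        # Mirror of the helper above: compile ModuleList children individually.
        if isinstance(model, nn.ModuleList):
            for _, module in model.named_children():
                # nn.Module.compile() wraps this module's __call__ in place,
                # unlike torch.compile(module), which returns a new wrapper object.
                module.compile()
        else:
            for _, module in model.named_children():
                compile_regions(module)

    model = TinyEncoder()
    compile_regions(model)            # each of the 4 blocks gets its own compiled region
    out = model(torch.randn(2, 64))   # first call triggers per-block compilation

Compiling the repeated blocks individually keeps each compiled graph small and lets identical layers reuse cached artifacts, which is the usual motivation for regional compilation over whole-model `torch.compile`.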
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/t5/modeling_t5.py
@@ -406,7 +406,7 @@ def gaudi_T5Stack_forward(

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.forward,
+                    layer_module.__call__,
hidden_states,
extended_attention_mask,
position_bias,
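The one-line T5 change swaps `layer_module.forward` for `layer_module.__call__` as the function handed to gradient checkpointing. `nn.Module.compile()` installs the compiled entry point on the module's call path, so `__call__` dispatches to the compiled code while the bare `forward` method stays eager; checkpointing `forward` therefore silently bypasses every region compiled by `compile_regions`, which matches the reported FLAN-T5 slowdown. A minimal illustration follows; the `block` module is a stand-in, not repository code:

    import torch
    from torch.utils.checkpoint import checkpoint

    block = torch.nn.Linear(8, 8)
    block.compile()  # in-place regional compilation, as in compile_regions above

    x = torch.randn(4, 8, requires_grad=True)

    # block.forward is the original eager method; checkpointing it skips the
    # compiled entry point that block.compile() installed.
    y_eager = checkpoint(block.forward, x, use_reentrant=False)

    # block.__call__ dispatches through the compiled path (and runs module hooks),
    # both in the forward pass and when checkpoint recomputes during backward.
    y_fast = checkpoint(block.__call__, x, use_reentrant=False)

Because checkpointing re-invokes the same callable to recompute activations during the backward pass, routing it through `__call__` also keeps that recomputation on the compiled path.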
