Skip to content

Commit

Permalink
related-change with deepspeed#5445
Browse files Browse the repository at this point in the history
  • Loading branch information
inkcherry committed Dec 25, 2024
1 parent 89bb319 commit c6fa522
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
7 changes: 7 additions & 0 deletions intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def IPEX_WEIGHT_PREPACK_MODULE_CPU():
deepspeed_modules_mapping.update(
{LmHeadLinearAllreduce: _IPEXLmHeadLinearAllreduce}
)
if len(deepspeed_modules) > 3:
for module in deepspeed_modules[3:]:
if module not in deepspeed_modules_mapping:
if issubclass(module, LinearAllreduce):
deepspeed_modules_mapping[module] = _IPEXLinearAllreduce
elif issubclass(module, LinearLayer):
deepspeed_modules_mapping[module] = _IPEXLinear
torch_modules.update(deepspeed_modules_mapping)

return torch_modules
Expand Down
3 changes: 3 additions & 0 deletions intel_extension_for_pytorch/nn/utils/_weight_prepack.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def may_import_deepspeed_modules():
try:
# import deepspeed in a global space will raise circular import error
# intel-extension-for-deepspeed imports both IPEX and deepspeed
import deepspeed.module_inject.layers as dslayers
from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer

ds_layers = [LinearAllreduce, LinearLayer]
Expand All @@ -110,6 +111,8 @@ def may_import_deepspeed_modules():
from deepspeed.module_inject.layers import LmHeadLinearAllreduce

ds_layers.append(LmHeadLinearAllreduce)
ds_layers += [cls for cls in dslayers.LinearAllreduce.__subclasses__()]
ds_layers += [cls for cls in dslayers.LinearLayer.__subclasses__()]
return ds_layers
except ImportError:
return ds_layers
Expand Down

0 comments on commit c6fa522

Please sign in to comment.