Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect patch in zero.init #5921

Closed

Conversation

VeryLazyBoy
Copy link
Contributor

The code below has a problem where cls.__init__ in line 525 can be modified before assignment to _old_init. This could lead to an incorrect __init__ being backed up:

def _enable_class(cls):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
def _init_subclass(cls, **kwargs):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_enable_class(subclass)

Test Case

import deepspeed
from torch import nn


class ModelA(nn.Module):
    def __init__(self):
        super().__init__()


class ModelB(ModelA):
    pass


original_init = ModelA.__init__


ds_config = {
    'fp16': {'enabled': False},
    'bf16': {'enabled': True},
    'zero_optimization': {
        'stage': 3,
        'offload_optimizer': {
            'device': 'cpu',
            'pin_memory': True
        },
        'offload_param': {
            'device': 'cpu',
            'pin_memory': True
        },
    },
    'gradient_accumulation_steps': 1,
    'gradient_clipping': 1,
    'train_batch_size': 1,
    'train_micro_batch_size_per_gpu': 1
}


with deepspeed.zero.Init(config_dict_or_path=ds_config, enabled=True, mem_efficient_linear=False, mpu=None):
    model_a = ModelA()
    assert ModelA.__init__ != original_init

assert ModelA.__init__ == original_init
assert ModelB.__init__ == original_init   #  Fails here. If not, please try several times since it depends on the order of modifications

@VeryLazyBoy
Copy link
Contributor Author

@microsoft-github-policy-service agree

@VeryLazyBoy
Copy link
Contributor Author

A better solution is proposed to handle _init_subclass as well

@tjruwase tjruwase requested a review from tohtana August 15, 2024 17:53
@tohtana
Copy link
Contributor

tohtana commented Aug 21, 2024

Thank you @VeryLazyBoy for the great catch!

I think the issue is that we patch superclass's cls.__init__ when cls doesn't have its __init__. So I try another approach in this branch. Do you think if this works?
This is less intrusive as we do not set __init__ when cls doesn't have it.

@VeryLazyBoy
Copy link
Contributor Author

@tohtana Yes! Your approach is less intrusive and much better. Let's go ahead with this new method. Should I close this merge request?

@tohtana
Copy link
Contributor

tohtana commented Aug 21, 2024

@VeryLazyBoy Thank you for your response!
Let me create a PR using my branch to make sure it works. Let's close this PR after all test pass with the PR.

github-merge-queue bot pushed a commit that referenced this pull request Sep 4, 2024
This PR fixes an issue addressed in #5921.
With this change, we only apply the patch for parameter partitioning to
classes that have `__init__` so that we can avoid applying the patch
multiple times.
The class that does not have `__init__` now uses its superclass's one.
So this PR also applies the patch to the root class,
`torch.nn.modules.module.Module`.

Thanks @VeryLazyBoy for the report and initial solution.

---------

Co-authored-by: Logan Adams <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants