
Why is ZeRO1 partitioning the model parameters but not the optimizer state? #1

Open
calico-niko opened this issue Jul 8, 2024 · 1 comment


@calico-niko

Hi, awesome project. I've learned a lot from it; thanks for your great work!

However, the partitioning policy confuses me. From the blog post I read, the ZeRO-1 policy only partitions the optimizer state (self.m and self.v in the code), yet the code shown below also splits the model parameters. Would you mind explaining why?

for idx, (_, param) in enumerate(self._local_params()):
    si_s, si_e = self.shard_indices[idx]
    # Copy this rank's slice of the parameters into the flat fp32 master buffer.
    self.sharded_fp32_master_param[si_s:si_e] = param.data.view(-1).float()
    # Set grad as well: make param.grad a view into the flat local gradient buffer.
    param.grad = torch.zeros_like(param.data)
    param.grad.data = self.local_grad_buffer_hp[si_s:si_e].view_as(param.data)
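For what it's worth, the last two lines use a plain PyTorch aliasing trick: param.grad is pointed at a slice of one flat buffer, so the whole rank-local gradient can live in a single contiguous tensor. A standalone sketch (my own toy example, not this repo's code):

import torch

# Toy example of the aliasing in the last two lines above: param.grad becomes a
# view into a flat buffer, so in-place writes to the buffer (e.g. after a
# gradient reduction) are immediately visible as the parameter's gradient.
param = torch.nn.Parameter(torch.randn(2, 3))
flat_grads = torch.zeros(param.numel())

param.grad = torch.zeros_like(param.data)
param.grad.data = flat_grads[: param.numel()].view_as(param.data)  # alias, no copy

flat_grads.fill_(1.0)
print(param.grad)  # all ones: the gradient tensor shares storage with flat_grads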

@MostHumble

@calico-niko Not sure if it's useful to you, but here's what I understood:
The optimizer state consists of a copy of the parameters plus the momentum/variance buffers (self.m and self.v):

[image: illustration of the optimizer state components]
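That is also where the memory savings come from. As a rough sanity check (my own back-of-the-envelope numbers following the ZeRO paper's mixed-precision Adam accounting, not anything in this repo):

# Rough memory accounting in the style of the ZeRO paper, assuming
# mixed-precision Adam: the optimizer state is K = 12 bytes per parameter
# (fp32 master copy + m + v, 4 bytes each), and ZeRO-1 divides only that
# term by the data-parallel world size.
def bytes_per_rank(num_params, world_size, K=12):
    fp16_params = 2 * num_params               # half-precision model weights
    fp16_grads = 2 * num_params                # half-precision gradients
    opt_state = K * num_params / world_size    # sharded optimizer state
    return fp16_params + fp16_grads + opt_state

print(bytes_per_rank(7.5e9, 64) / 1e9)   # ~31.4 GB per GPU with ZeRO-1
print(bytes_per_rank(7.5e9, 1) / 1e9)    # ~120 GB without sharding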

Along with the parameter sharding you mentioned, the momentum buffers are sharded here as well (the loop only visits _local_params, so the final current_offset is the size of this rank's shard, not of the full model):

current_offset = 0
# Initialize config per-shard.
for gidx, param in self._local_params():
    self.offsets.append(param.data.view(-1).size(0))
    self.shard_indices.append(
        (current_offset, current_offset + param.data.view(-1).size(0))
    )
    current_offset += param.data.view(-1).size(0)
    self.local_param_indices.add(gidx)
self.v = torch.zeros(current_offset).to(self.device)
self.m = torch.zeros(current_offset).to(self.device)
self.sharded_fp32_master_param = torch.zeros(current_offset).to(self.device)
self.local_grad_buffer_hp = torch.zeros(current_offset).to(
    self.device, dtype=forward_dtype
)
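I haven't checked exactly how _local_params assigns ownership, but judging by how it's used (it yields (global_index, param) pairs, and only for this rank's parameters), a hypothetical sketch could look like the following; the round-robin rule is my assumption, not the repo's actual logic:

# Hypothetical sketch of _local_params (assumed, not the actual implementation):
# yield only the parameters this rank owns, together with their global index.
def _local_params(self):
    for gidx, param in enumerate(self.model.parameters()):
        if gidx % self.world_size == self.rank:  # assumed round-robin ownership
            yield gidx, param

So self.m, self.v, and self.sharded_fp32_master_param are each allocated with only current_offset elements, i.e. one shard of the optimizer state per rank.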
