Skip to content

Commit

Permalink
clean up inflight status params
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana committed Sep 21, 2024
1 parent 8f81634 commit 7bc3c66
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ def construct_parameter_trace_from_module_trace(self):
def reset_step(self) -> None:
"""indicate that we have completed one fwd+bwd for the model"""
if self.__inflight_param_registry:
raise RuntimeError(f"still have inflight params "
f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}")
for param, handle in self.__inflight_param_registry.items():
handle.wait()
self.__release_param(param)
self.__inflight_param_registry.clear()

if not self.is_complete_trace(): # not self.trace_complete:
# Make sure that recorded submodule orders are identical across ranks
Expand Down Expand Up @@ -409,7 +411,7 @@ def release_and_reset_all(self, module: Module) -> None:
"""release all module parameters"""
for param in iter_params(module, recurse=True):
if param in self.__inflight_param_registry:
raise RuntimeError(f"param {param.ds_summary()} still in flight")
self.__inflight_param_registry.pop(param).wait()

# TODO. make this throw if if there are still active submodules. currently
# there's a hook execution issue
Expand Down

0 comments on commit 7bc3c66

Please sign in to comment.