diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index bdec8a55fcbc..9f7dab077c3c 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -205,8 +205,10 @@ def construct_parameter_trace_from_module_trace(self): def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" if self.__inflight_param_registry: - raise RuntimeError(f"still have inflight params " - f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}") + for param, handle in self.__inflight_param_registry.items(): + handle.wait() + self.__release_param(param) + self.__inflight_param_registry.clear() if not self.is_complete_trace(): # not self.trace_complete: # Make sure that recorded submodule orders are identical across ranks @@ -409,7 +411,7 @@ def release_and_reset_all(self, module: Module) -> None: """release all module parameters""" for param in iter_params(module, recurse=True): if param in self.__inflight_param_registry: - raise RuntimeError(f"param {param.ds_summary()} still in flight") + self.__inflight_param_registry.pop(param).wait() # TODO. make this throw if if there are still active submodules. currently # there's a hook execution issue