From 3adee99c6f012d91c9129247ef135a060332fb3f Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Thu, 5 Dec 2024 15:57:42 -0800
Subject: [PATCH] add comments

---
 axlearn/common/optimizers.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py
index c120d6b13..5444435c4 100644
--- a/axlearn/common/optimizers.py
+++ b/axlearn/common/optimizers.py
@@ -141,15 +141,15 @@ def update_fn(
 
 
 def copy_partition(
-    param_specs: Nested[ParameterSpec],
+    specs: Nested[OptStateSpec],
     *,
     pattern: Union[None, str, re.Pattern] = None,
     memory_kind: Optional[MemoryKind] = None,
 ) -> Nested[OptStateSpec]:
-    """Creates OptStateSpec from ParameterSpec with possibly a different memory kind.
+    """Copies OptStateSpec and optionally assigns a different memory kind.
 
     Args:
-        param_specs: Nested[ParameterSpec] to copy from.
+        specs: Nested[OptStateSpec] to copy from.
         pattern: Regex to match the full path of each spec. Matched specs will have their
            memory kind replaced with `memory_kind`.
         memory_kind: New memory kind. Default to None.
@@ -157,16 +157,16 @@
         A Nested[OptStateSpec] with possibly a different memory kind.
     """
     return jax.tree.map(
-        lambda path, param_spec: OptStateSpec(
-            dtype=param_spec.dtype,
-            shape=param_spec.shape,
-            mesh_axes=param_spec.mesh_axes,
+        lambda path, spec: OptStateSpec(
+            dtype=spec.dtype,
+            shape=spec.shape,
+            mesh_axes=spec.mesh_axes,
             memory_kind=memory_kind
             if pattern and re.fullmatch(pattern, path)
-            else param_spec.memory_kind,
+            else spec.memory_kind,
         ),
-        tree_paths(param_specs),
-        param_specs,
+        tree_paths(specs),
+        specs,
     )
 
 
@@ -2085,6 +2085,13 @@ def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
     # sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about it,
     # it's specified in the API signature. Reference:
     # https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
+    # Note: device_put doesn't move everything at once. When we pass a pytree of arrays to
+    # device_put, each array in the pytree is moved independently of the others. The exact
+    # order is decided by the latency hiding scheduler, which tries to overlap the transfer of
+    # each state with the state update on TPU whenever possible. There is some memory spike
+    # from the temporary states in HBM, but the spike is much smaller than the full memory
+    # usage of all states. Moreover, when the optimizer runs, all activations have been
+    # released, so there is less memory pressure at that point in time.
     return jax.tree.map(
         lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
         if re.fullmatch(pattern, path)
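
Note (not part of the patch): a minimal, self-contained sketch of the retag-by-path idea that
copy_partition implements, for readers outside the axlearn tree. It uses public JAX tree
utilities and a stand-in SimpleSpec/retag pair instead of axlearn's OptStateSpec/tree_paths;
those names, the path format, and the "pinned_host" kind are illustrative assumptions.

import dataclasses
import re
from typing import Optional

import jax


@dataclasses.dataclass
class SimpleSpec:
    # Stand-in for an optimizer state spec; only the fields needed for the sketch.
    shape: tuple
    memory_kind: Optional[str] = None


def retag(specs, *, pattern: Optional[str] = None, memory_kind: Optional[str] = None):
    """Returns a copy of `specs` where leaves whose path matches `pattern` get `memory_kind`."""

    def leaf_path(path) -> str:
        # Turn jax.tree_util key entries into an "outer/inner" style path string.
        return "/".join(str(getattr(k, "key", getattr(k, "idx", k))) for k in path)

    return jax.tree_util.tree_map_with_path(
        lambda path, spec: dataclasses.replace(
            spec,
            memory_kind=memory_kind
            if pattern and re.fullmatch(pattern, leaf_path(path))
            else spec.memory_kind,
        ),
        specs,
    )


specs = {"ema": {"mu": SimpleSpec((1024,)), "nu": SimpleSpec((1024,))}, "count": SimpleSpec(())}
offloaded = retag(specs, pattern=r".*/(mu|nu)", memory_kind="pinned_host")
# Only ema/mu and ema/nu are retagged; count keeps its original memory kind.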
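
Note (not part of the patch): the new comment in _move_fn describes leaf-by-leaf transfers. The
sketch below only illustrates that per-leaf device_put pattern on a toy pytree. It assumes a
backend with memory-kind support (e.g. TPU); TransferToMemoryKind is the same symbol the
surrounding axlearn code uses, but the import path here is an assumption and may vary across
JAX versions.

import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind  # import path is an assumption

state = {"mu": jnp.ones((1024, 1024)), "nu": jnp.ones((1024, 1024))}

# One transfer is issued per leaf; nothing forces the whole pytree to move as a single blob,
# so the scheduler can order and overlap the individual copies with other work.
offloaded = jax.tree.map(
    lambda x: jax.device_put(x, TransferToMemoryKind("pinned_host")), state
)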