add comments
hanzhi713 committed Dec 5, 2024
1 parent 9eabbee commit 3adee99
Showing 1 changed file with 17 additions and 10 deletions.
axlearn/common/optimizers.py (27 changes: 17 additions & 10 deletions)
@@ -141,32 +141,32 @@ def update_fn(
 
 
 def copy_partition(
-    param_specs: Nested[ParameterSpec],
+    specs: Nested[OptStateSpec],
     *,
     pattern: Union[None, str, re.Pattern] = None,
     memory_kind: Optional[MemoryKind] = None,
 ) -> Nested[OptStateSpec]:
-    """Creates OptStateSpec from ParameterSpec with possibly a different memory kind.
+    """Copies OptStateSpec and optionally assigns a different memory kind.
 
     Args:
-        param_specs: Nested[ParameterSpec] to copy from.
+        specs: Nested[OptStateSpec] to copy from.
         pattern: Regex to match the full path of each spec. Matched specs will have their memory
             kind replaced with `memory_kind`.
         memory_kind: New memory kind. Defaults to None.
 
     Returns:
         A Nested[OptStateSpec] with possibly a different memory kind.
     """
     return jax.tree.map(
-        lambda path, param_spec: OptStateSpec(
-            dtype=param_spec.dtype,
-            shape=param_spec.shape,
-            mesh_axes=param_spec.mesh_axes,
+        lambda path, spec: OptStateSpec(
+            dtype=spec.dtype,
+            shape=spec.shape,
+            mesh_axes=spec.mesh_axes,
             memory_kind=memory_kind
             if pattern and re.fullmatch(pattern, path)
-            else param_spec.memory_kind,
+            else spec.memory_kind,
         ),
-        tree_paths(param_specs),
-        param_specs,
+        tree_paths(specs),
+        specs,
     )
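
For context, a minimal usage sketch of the updated helper (not part of the commit): it copies a tree of OptStateSpec and rewrites memory_kind only for tree paths that fully match the regex. The spec values, the "pinned_host" memory kind, and the OptStateSpec import path below are assumptions for illustration.

# Hypothetical usage of copy_partition; names and values are illustrative, not from the commit.
import jax.numpy as jnp
from jax.sharding import PartitionSpec

from axlearn.common.optimizer_base import OptStateSpec  # assumed import location
from axlearn.common.optimizers import copy_partition

specs = {
    "ema": OptStateSpec(dtype=jnp.float32, shape=(1024, 1024), mesh_axes=PartitionSpec("data", None)),
    "count": OptStateSpec(dtype=jnp.int32, shape=(), mesh_axes=PartitionSpec()),
}
# Specs whose full tree path matches the pattern get memory_kind="pinned_host"; the rest keep theirs.
host_specs = copy_partition(specs, pattern=".*ema.*", memory_kind="pinned_host")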


@@ -2085,6 +2085,13 @@ def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
         # sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about it,
         # it's specified in the API signature. Reference:
         # https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
+        # Note: device_put doesn't move everything at once. When we pass a pytree of arrays to
+        # device_put, each array in the pytree is moved independently of the others. The exact
+        # order is decided by the latency hiding scheduler. The scheduler will try to overlap the
+        # transfers of each state with the state update on TPU whenever possible. There is some
+        # memory spike due to the temporary state in HBM, but the spike is much smaller than the
+        # full memory usage of all states. Moreover, when the optimizer is run, all activations
+        # are released, so we have less memory pressure at that point in time.
         return jax.tree.map(
             lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
             if re.fullmatch(pattern, path)
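
To make the added note concrete, here is a hedged, self-contained sketch of the transfer pattern it describes, outside of axlearn. The state names, the "pinned_host" destination, and the TransferToMemoryKind import path are assumptions; the exact import location varies across JAX versions.

# Illustration only, not code from the commit: passing a whole pytree to device_put
# issues one transfer per leaf rather than one bulk copy of the entire tree.
import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind  # assumed path; may differ by JAX version

# A toy "optimizer state" pytree living in device memory (HBM).
state = {"mu": jnp.ones((1024, 1024)), "nu": jnp.ones((1024, 1024))}

# Each leaf is moved to pinned host memory by its own transfer, which the latency hiding
# scheduler can overlap with other work inside a jitted step.
host_state = jax.device_put(state, TransferToMemoryKind("pinned_host"))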
