
Commit

have inference runner use the same custom create_device_mesh as trainer (#87)

* have inference runner use the same custom create_device_mesh as trainer
JinhaoLei authored Sep 29, 2023
1 parent 0491cf9 commit 70eb15f
Showing 5 changed files with 201 additions and 203 deletions.
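
The substance of the change: both the inference runner and the trainer now build their device mesh through the shared helper axlearn.common.utils.create_device_mesh instead of calling jax.experimental.mesh_utils directly. A minimal sketch of the resulting call pattern, mirroring the call sites in the diffs below; the mesh shape and axis names here are illustrative, not taken from this commit:

import jax

from axlearn.common import utils

# In the trainer and inference runner, the mesh shape and axis names come from
# cfg.mesh_shape and cfg.mesh_axis_names; the product of mesh_shape must equal
# the number of available devices.
mesh_shape = (1, len(jax.devices()))
devices = utils.create_device_mesh(mesh_shape=mesh_shape)  # numpy array of JAX devices
mesh = jax.sharding.Mesh(devices, ("data", "model"))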
3 changes: 1 addition & 2 deletions axlearn/common/inference.py
@@ -30,7 +30,6 @@
import jax.numpy as jnp
import numpy as np
from absl import logging
-from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit

from axlearn.common import utils
@@ -212,7 +211,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
[device.platform for device in jax.local_devices()],
)
logging.info("Mesh shape: %s", cfg.mesh_shape)
-devices = mesh_utils.create_device_mesh(cfg.mesh_shape)
+devices = utils.create_device_mesh(cfg.mesh_shape)
mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names)
logging.info("Global mesh: %s", mesh)
self._mesh = mesh
94 changes: 1 addition & 93 deletions axlearn/common/trainer.py
@@ -7,11 +7,9 @@
from typing import Any, Dict, Literal, NamedTuple, Optional, Sequence, Tuple, Union

import jax
-import numpy as np
import tensorflow as tf
from absl import logging
from jax import numpy as jnp
-from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit

from axlearn.common import utils
@@ -159,7 +157,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
[device.platform for device in jax.local_devices()],
)
self._step_log("Mesh shape: %s", cfg.mesh_shape)
-devices = _create_device_mesh(mesh_shape=cfg.mesh_shape)
+devices = utils.create_device_mesh(mesh_shape=cfg.mesh_shape)
mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names)
self._step_log("Global mesh: %s", mesh)
self._mesh = mesh
@@ -747,93 +745,3 @@ def _forward(model_parameters_grad, model_parameters_no_grad, forward_input_batc
loss=loss,
aux=forward_aux,
)


def _create_device_mesh(
mesh_shape: Sequence[int], *, devices: Optional[Sequence[Any]] = None
) -> np.ndarray:
"""Constructs a device mesh.
We first determine whether we are running in a TPU or GPU environment.
- If running in a TPU environment:
- If multi-slice/granule, we split the first axis of the configured
mesh shape across the slices.
- If running in a GPU environment:
- If the first axis divides the number of processes (GPU-nodes/granules), we
split the first axis across the processes.
In all other cases we construct a standard mesh according to the configured mesh_shape.
TODO(tom_gunter): Allow for more inter/intra granule mesh config flexibility.
Args:
mesh_shape: The desired logical mesh shape.
devices: The devices that will be used to construct the mesh.
If None, defaults to jax.devices().
Returns:
A numpy array containing the JAX devices with shape determined by the config mesh_shape.
Raises:
NotImplementedError: If not all devices have the same platform.
"""
if devices is None:
devices = jax.devices()
devices = np.asarray(devices)

def build_standard_mesh():
logging.info("Building device mesh.")
try:
return mesh_utils.create_device_mesh(mesh_shape, devices=devices)
except NotImplementedError as e:
logging.warning(
"mesh_utils.create_device_mesh cannot handle shape %s: %s. "
"Falling back to the naive mesh. Performance may be reduced.",
mesh_shape,
e,
)
return devices.reshape(mesh_shape)

# Check if the devices are part of a multi-granule configuration.
# <https://github.com/google/jax/blob/b81b79c1b0d2ec/jax/experimental/mesh_utils.py#L313>
device_platform = devices[0].platform
attr = "process_index" if device_platform != "tpu" else "slice_index"
is_multi_granule_env = hasattr(devices[0], attr)
if not all(el.platform == device_platform for el in devices):
raise NotImplementedError(f"Not all devices had platform: {device_platform}.")

# Return standard mesh if not a multi-slice/granule env.
if not is_multi_granule_env:
return build_standard_mesh()

ici_mesh_shape = mesh_shape
num_granules = max([getattr(el, attr) for el in devices.flatten()]) + 1

# Return standard mesh if on GPU with incompatible multi-slice/granule mesh.
if device_platform == "gpu" and ici_mesh_shape[0] % num_granules != 0:
logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.")
return build_standard_mesh()

# We only break the first device axis (the least communication intensive) across granules.
assert (
ici_mesh_shape[0] % num_granules == 0
), "First mesh shape axis must divide num slices/granules."
logging.info("Building multi-slice/granule device mesh.")
# Truncate intra-slice/granule mesh.
ici_mesh_shape = (ici_mesh_shape[0] // num_granules, *ici_mesh_shape[1:])
logging.info("Inferred intra-slice/granule mesh shape: %s", ici_mesh_shape)
# Configure data center (inter-slice/granule) mesh.
dcn_mesh_shape = (num_granules,) + (1,) * len(ici_mesh_shape[1:])
logging.info("Inferred inter-slice/granule mesh shape: %s", dcn_mesh_shape)
# Check we have the right number of devices.
total_parallelism = np.product(dcn_mesh_shape) * np.product(ici_mesh_shape)
assert total_parallelism == len(devices), (
f"Num devices {len(devices)} does not match the product of "
f"inter and intra slice/granule parallelism {total_parallelism}."
)
return mesh_utils.create_hybrid_device_mesh(
ici_mesh_shape,
dcn_mesh_shape=dcn_mesh_shape,
devices=devices,
process_is_granule=attr == "process_index",
)
108 changes: 1 addition & 107 deletions axlearn/common/trainer_test.py
@@ -3,7 +3,6 @@
"""Tests SpmdTrainer."""
# pylint: disable=no-self-use
import copy
-import dataclasses
import os.path
import shutil
import tempfile
@@ -32,7 +31,7 @@
from axlearn.common.learner import UpdateType, should_update_with_optimizers
from axlearn.common.module import Module
from axlearn.common.state_builder import Builder as TrainerStateBuilder
-from axlearn.common.trainer import SpmdTrainer, _create_device_mesh, _prune_empty, _TrainerState
+from axlearn.common.trainer import SpmdTrainer, _prune_empty, _TrainerState
from axlearn.common.utils import NestedTensor, Tensor, as_tensor, flatten_items, match_regex_rules

FLAGS = flags.FLAGS
@@ -813,110 +812,5 @@ def test_composite_learner(self):
self.assertGreater(np.max(np.abs(updated_p - init_p)), 1e-3, msg=path)


@dataclasses.dataclass(frozen=True)
class DummyDevice:
"""Mock device for testing."""

platform: str
device_kind: str
process_index: int


@dataclasses.dataclass(frozen=True)
class DummyTpuDevice(DummyDevice):
"""Mock TPU device for testing."""

coords: Sequence[int]
core_on_chip: int = 0


@dataclasses.dataclass(frozen=True)
class DummyMultiSliceTpuDevice(DummyTpuDevice):
"""Mock multi-slice TPU device for testing."""

slice_index: int = 0


class DeviceMeshTest(test_utils.TestCase):
@parameterized.parameters(
{"logical_mesh": (2, 8)},
{"logical_mesh": (4, 4)},
{"logical_mesh": (1, 2, 8)},
)
def test_create_device_mesh_tpuv4(self, logical_mesh: Sequence[int]):
physical_mesh = (4, 4, 1)
coords = [
(x, y, z)
for x in range(physical_mesh[0])
for y in range(physical_mesh[1])
for z in range(physical_mesh[2])
]
devices = [
DummyTpuDevice(
platform="tpu",
device_kind="TPU v4",
process_index=ix // 4,
coords=coord,
)
for ix, coord in enumerate(coords)
]
# Check that the constructed mesh has the expected shape.
self.assertEqual(
_create_device_mesh(mesh_shape=logical_mesh, devices=devices).shape, logical_mesh
)

@parameterized.parameters(
{"logical_mesh": (2, 16)},
{"logical_mesh": (2, 4, 4)},
)
def test_create_device_mesh_multi_slice_tpuv4(self, logical_mesh: Sequence[int]):
slice_physical_mesh = (4, 4, 1)
num_slices = 2
coords = [
(x, y, z)
for x in range(slice_physical_mesh[0])
for y in range(slice_physical_mesh[1])
for z in range(slice_physical_mesh[2])
]
devices = [
DummyMultiSliceTpuDevice(
platform="tpu",
device_kind="TPU v4",
process_index=(len(coords) * slice_index + ix) // 4,
coords=coord,
slice_index=slice_index,
)
for ix, coord in enumerate(coords)
for slice_index in range(num_slices)
]
# Check that the constructed mesh has the expected shape.
device_mesh = _create_device_mesh(mesh_shape=logical_mesh, devices=devices)
self.assertEqual(device_mesh.shape, logical_mesh)
# Check that the sub_mesh along the first axis only contains devices from one of the slices.
for ix, sub_mesh in enumerate(device_mesh):
self.assertTrue(all(el.slice_index == ix for el in sub_mesh.flatten()))

@parameterized.parameters(
{"logical_mesh": (8, 2, 4)},
{"logical_mesh": (16, 4)},
{"logical_mesh": (2, 32)},
)
def test_create_device_mesh_gpu(self, logical_mesh: Sequence[int] = (8, 2, 4)):
num_gpus_per_process = 8
num_granules = 8
devices = [
DummyDevice(
platform="gpu",
device_kind="gpu",
process_index=(num_gpus_per_process * granule_index + ix) // num_gpus_per_process,
)
for ix in range(num_gpus_per_process)
for granule_index in range(num_granules)
]
# Check that the constructed mesh has the expected shape.
device_mesh = _create_device_mesh(mesh_shape=logical_mesh, devices=devices)
self.assertEqual(device_mesh.shape, logical_mesh)


if __name__ == "__main__":
absltest.main()
92 changes: 91 additions & 1 deletion axlearn/common/utils.py
@@ -35,7 +35,7 @@
from absl import logging
from flax import serialization
from jax import numpy as jnp
-from jax.experimental import maps, multihost_utils, pjit
+from jax.experimental import maps, mesh_utils, multihost_utils, pjit
from jax.sharding import PartitionSpec
from jax.tree_util import register_pytree_node_class

@@ -922,3 +922,93 @@ def register_per_param_settings(settings: NestedTree, *, description: str):
for path, setting in flatten_items(settings):
logging.info("Per-param setting %s: %s=%s", description, path, setting)
return settings


def create_device_mesh(
mesh_shape: Sequence[int], *, devices: Optional[Sequence[Any]] = None
) -> np.ndarray:
"""Constructs a device mesh.
We first determine whether we are running in a TPU or GPU environment.
- If running in a TPU environment:
- If multi-slice/granule, we split the first axis of the configured
mesh shape across the slices.
- If running in a GPU environment:
- If the first axis divides the number of processes (GPU-nodes/granules), we
split the first axis across the processes.
In all other cases we construct a standard mesh according to the configured mesh_shape.
TODO(tom_gunter): Allow for more inter/intra granule mesh config flexibility.
Args:
mesh_shape: The desired logical mesh shape.
devices: The devices that will be used to construct the mesh.
If None, defaults to jax.devices().
Returns:
A numpy array containing the JAX devices with shape determined by the config mesh_shape.
Raises:
NotImplementedError: If not all devices have the same platform.
"""
if devices is None:
devices = jax.devices()
devices = np.asarray(devices)

def build_standard_mesh():
logging.info("Building device mesh.")
try:
return mesh_utils.create_device_mesh(mesh_shape, devices=devices)
except NotImplementedError as e:
logging.warning(
"mesh_utils.create_device_mesh cannot handle shape %s: %s. "
"Falling back to the naive mesh. Performance may be reduced.",
mesh_shape,
e,
)
return devices.reshape(mesh_shape)

# Check if the devices are part of a multi-granule configuration.
# <https://github.com/google/jax/blob/b81b79c1b0d2ec/jax/experimental/mesh_utils.py#L313>
device_platform = devices[0].platform
attr = "process_index" if device_platform != "tpu" else "slice_index"
is_multi_granule_env = hasattr(devices[0], attr)
if not all(el.platform == device_platform for el in devices):
raise NotImplementedError(f"Not all devices had platform: {device_platform}.")

# Return standard mesh if not a multi-slice/granule env.
if not is_multi_granule_env:
return build_standard_mesh()

ici_mesh_shape = mesh_shape
num_granules = max([getattr(el, attr) for el in devices.flatten()]) + 1

# Return standard mesh if on GPU with incompatible multi-slice/granule mesh.
if device_platform == "gpu" and ici_mesh_shape[0] % num_granules != 0:
logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.")
return build_standard_mesh()

# We only break the first device axis (the least communication intensive) across granules.
assert (
ici_mesh_shape[0] % num_granules == 0
), "First mesh shape axis must divide num slices/granules."
logging.info("Building multi-slice/granule device mesh.")
# Truncate intra-slice/granule mesh.
ici_mesh_shape = (ici_mesh_shape[0] // num_granules, *ici_mesh_shape[1:])
logging.info("Inferred intra-slice/granule mesh shape: %s", ici_mesh_shape)
# Configure data center (inter-slice/granule) mesh.
dcn_mesh_shape = (num_granules,) + (1,) * len(ici_mesh_shape[1:])
logging.info("Inferred inter-slice/granule mesh shape: %s", dcn_mesh_shape)
# Check we have the right number of devices.
total_parallelism = np.product(dcn_mesh_shape) * np.product(ici_mesh_shape)
assert total_parallelism == len(devices), (
f"Num devices {len(devices)} does not match the product of "
f"inter and intra slice/granule parallelism {total_parallelism}."
)
return mesh_utils.create_hybrid_device_mesh(
ici_mesh_shape,
dcn_mesh_shape=dcn_mesh_shape,
devices=devices,
process_is_granule=attr == "process_index",
)
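
For intuition, a small sketch of the shape bookkeeping create_device_mesh performs above when it splits the first mesh axis across slices/granules. The numbers (a (2, 16) logical mesh over two TPU slices) are illustrative, not taken from this commit:

# Hypothetical multi-slice example of the splitting logic above.
mesh_shape = (2, 16)  # configured logical mesh; first axis is split across slices
num_granules = 2      # e.g. two TPU slices
assert mesh_shape[0] % num_granules == 0, "first axis must be divisible by num_granules"
ici_mesh_shape = (mesh_shape[0] // num_granules, *mesh_shape[1:])  # (1, 16) within each slice
dcn_mesh_shape = (num_granules,) + (1,) * len(ici_mesh_shape[1:])  # (2, 1) across slices
# Total parallelism: 2 * 1 * 1 * 16 == 32 devices, i.e. two 16-chip slices.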

0 comments on commit 70eb15f
