Allow passing planner to _shard_modules (pytorch#2732)
Summary:

The `_shard_modules` function is used in fx_traceability tests for the SDD and SemiSync pipelines. It builds a default ShardingPlanner and Topology that use a hardcoded batch size (512) and HBM memory limit (32 GB), respectively. This change allows callers to specify the ShardingPlanner (and, through it, the Topology) so that they more accurately reflect the machine's capabilities. The change is intentionally limited to `_shard_modules` and does not touch the public `shard_modules`, to avoid changing the latter's contract.
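As an illustration of the new argument, the sketch below builds a planner whose Topology reflects the local machine rather than the hardcoded defaults and passes it to `_shard_modules`. This is a minimal sketch, not code from this commit: it assumes an initialized process group and available CUDA devices, and the table definition, `hbm_cap`, and `batch_size` values are placeholders chosen for the example.

    # Sketch only: assumes torch.distributed has been initialized (e.g. via
    # torchrun) and that CUDA devices are available.
    import torch
    import torch.distributed as dist
    from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
    from torchrec.distributed.comm import get_local_size
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.shard import _shard_modules

    # Placeholder model with a single embedding table, built on the meta device.
    model = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=64,
                num_embeddings=10_000,
                feature_names=["f1"],
            )
        ],
        device=torch.device("meta"),
    )

    world_size = dist.get_world_size()
    topology = Topology(
        local_world_size=get_local_size(world_size),
        world_size=world_size,
        compute_device="cuda",
        hbm_cap=40 * 1024**3,  # placeholder: per-device HBM budget in bytes
    )
    planner = EmbeddingShardingPlanner(
        topology=topology,
        batch_size=64,  # placeholder: overrides the planner's default of 512
    )

    sharded_model = _shard_modules(
        module=model,
        device=torch.device("cuda"),
        planner=planner,  # new argument added by this change
    )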

Reviewed By: sarckk

Differential Revision: D69163227
che-sh authored and facebook-github-bot committed Feb 12, 2025
1 parent 1afbf08 commit f269be7
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions torchrec/distributed/shard.py
@@ -204,6 +204,7 @@ def _shard_modules( # noqa: C901
     plan: Optional[ShardingPlan] = None,
     sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
     init_params: Optional[bool] = False,
+    planner: Optional[EmbeddingShardingPlanner] = None,
 ) -> nn.Module:
     """
     See shard_modules
@@ -238,13 +239,14 @@ def _shard_modules( # noqa: C901
         assert isinstance(
             env, ShardingEnv
         ), "Currently hybrid sharding only support use manual sharding plan"
-        planner = EmbeddingShardingPlanner(
-            topology=Topology(
-                local_world_size=get_local_size(env.world_size),
-                world_size=env.world_size,
-                compute_device=device.type,
+        if planner is None:
+            planner = EmbeddingShardingPlanner(
+                topology=Topology(
+                    local_world_size=get_local_size(env.world_size),
+                    world_size=env.world_size,
+                    compute_device=device.type,
+                )
             )
-        )
         pg = env.process_group
         if pg is not None:
             plan = planner.collective_plan(module, sharders, pg)
