diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py
index a755d2c8b..0a27711a7 100644
--- a/torchrec/distributed/shard.py
+++ b/torchrec/distributed/shard.py
@@ -204,6 +204,7 @@ def _shard_modules(  # noqa: C901
     plan: Optional[ShardingPlan] = None,
     sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
     init_params: Optional[bool] = False,
+    planner: Optional[EmbeddingShardingPlanner] = None,
 ) -> nn.Module:
     """
     See shard_modules
@@ -238,13 +239,14 @@ def _shard_modules(  # noqa: C901
         assert isinstance(
             env, ShardingEnv
         ), "Currently hybrid sharding only support use manual sharding plan"
-        planner = EmbeddingShardingPlanner(
-            topology=Topology(
-                local_world_size=get_local_size(env.world_size),
-                world_size=env.world_size,
-                compute_device=device.type,
+        if planner is None:
+            planner = EmbeddingShardingPlanner(
+                topology=Topology(
+                    local_world_size=get_local_size(env.world_size),
+                    world_size=env.world_size,
+                    compute_device=device.type,
+                )
             )
-        )
         pg = env.process_group
         if pg is not None:
             plan = planner.collective_plan(module, sharders, pg)