Allow passing planner to _shard_modules (pytorch#2732)
Summary: The `_shard_modules` function is used in fx_traceability tests for the SDD and SemiSync pipelines. It uses a default ShardingPlanner and Topology with a hardcoded batch size (512) and HBM memory limit (32 GB), respectively. This change allows callers to specify the ShardingPlanner and Topology so the plan more accurately reflects the machine's capabilities. The change is intentionally limited to `_shard_modules` and does not touch the public `shard_modules`, to avoid changing the latter's contract.

Reviewed By: sarckk

Differential Revision: D69163227