diff --git a/3.test_cases/10.FSDP/train.py b/3.test_cases/10.FSDP/train.py index bcd56057..002a08c2 100644 --- a/3.test_cases/10.FSDP/train.py +++ b/3.test_cases/10.FSDP/train.py @@ -155,7 +155,13 @@ def main(args): logger.info( "Creating Model" ) - model = AutoModelForCausalLM.from_config(model_config) + # Instantiate model on CPU on rank=0 only to prevent CPU OOM + # (e.g. 70B * 4 bytes * 8 processes > 2T RAM available on P5) + if global_rank == 0: + model = AutoModelForCausalLM.from_config(model_config) + else: + with torch.device("meta"): + model = AutoModelForCausalLM.from_config(model_config) num_params = compute_num_params(model) if global_rank == 0: @@ -191,6 +197,8 @@ def main(args): device_id=torch.cuda.current_device(), use_orig_params=False, sharding_strategy=sharding_strategy, + param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)) + if global_rank != 0 else None, ) if global_rank == 0: