diff --git a/src/slam_llm/pipeline/finetune_deepspeed.py b/src/slam_llm/pipeline/finetune_deepspeed.py
index 1dd8e3f7..8f275faf 100644
--- a/src/slam_llm/pipeline/finetune_deepspeed.py
+++ b/src/slam_llm/pipeline/finetune_deepspeed.py
@@ -145,6 +145,20 @@ def main(kwargs: DictConfig):
     parameters = filter(lambda p: p.requires_grad, model.parameters())
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # If you are facing problems due to limited memory (<=256GB), you can try replacing the above code with the following
+    # for i in range(rank):
+    #     while not os.path.isfile(f".{i}.done"):
+    #         pass
+    # assert not os.path.isfile(f".{rank}.done"), f".{rank}.done already exists!"
+    # model_factory = get_custom_model_factory(model_config, logger)
+    # model, tokenizer = model_factory(train_config, model_config, **kwargs)
+    # parameters = filter(lambda p: p.requires_grad, model.parameters())
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # model.half()
+    # with open(f".{rank}.done", "w"):
+    #     pass
+
+
 
     # Initialize the optimizer and learning rate scheduler
     model_engine, _, _, _ = deepspeed.initialize(
         model=model, model_parameters=parameters, config=deepspeed_config
diff --git a/src/slam_llm/utils/config_utils.py b/src/slam_llm/utils/config_utils.py
index 743b3734..b0aadf7d 100644
--- a/src/slam_llm/utils/config_utils.py
+++ b/src/slam_llm/utils/config_utils.py
@@ -69,7 +69,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
         kwargs = {}
         batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
         if train_config.batching_strategy == "padding":
-            if train_config.enable_fsdp or train_config.enable_ddp:
+            if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
                 kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                     dataset,
                     batch_size=batch_size,
@@ -81,7 +81,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
                 kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
             kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
         elif train_config.batching_strategy == "packing":
-            if train_config.enable_fsdp or train_config.enable_ddp:
+            if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
                 kwargs["sampler"] = DistributedSampler(
                 dataset,
                 rank=dist.get_rank(),
@@ -93,7 +93,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
             kwargs["collate_fn"] = default_data_collator
        else:
             # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-            if train_config.enable_fsdp or train_config.enable_ddp:
+            if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
                 kwargs["sampler"] = DistributedSampler(
                 dataset,
                 rank=dist.get_rank(),
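Note: the commented-out block added to finetune_deepspeed.py serializes model construction across ranks. Each rank waits for every lower rank to drop a ".{i}.done" marker file before building its own copy, so only one rank materializes full-precision weights at a time. Below is a minimal standalone sketch of that gating pattern; the load_model callable, the time.sleep polling, and the final cleanup step are illustrative assumptions and not part of this patch.

import os
import time
from typing import Callable

def staggered_load(load_model: Callable, rank: int, world_size: int):
    """Build the model one rank at a time to cap peak host memory."""
    # Block until every lower rank has finished loading its copy.
    for i in range(rank):
        while not os.path.isfile(f".{i}.done"):
            time.sleep(1)

    assert not os.path.isfile(f".{rank}.done"), f".{rank}.done already exists!"

    model = load_model()  # hypothetical loader supplied by the caller
    model.half()          # shrink the freshly built weights to fp16

    # Signal the next rank that it may start loading.
    with open(f".{rank}.done", "w"):
        pass

    # The last rank removes the marker files once every rank is through.
    if rank == world_size - 1:
        for i in range(world_size):
            os.remove(f".{i}.done")

    return model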