Commit

fix deepspeed dataloader
zzasdf committed May 31, 2024
1 parent c930354 commit a77c88a
Showing 2 changed files with 17 additions and 3 deletions.
14 changes: 14 additions & 0 deletions src/slam_llm/pipeline/finetune_deepspeed.py
@@ -145,6 +145,20 @@ def main(kwargs: DictConfig):
parameters = filter(lambda p: p.requires_grad, model.parameters())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # If you are facing problems due to limited memory (<=256 GB), you can try replacing the above code with the following:
+ # for i in range(rank):
+ #     while not os.path.isfile(f".{i}.done"):
+ #         pass
+ # assert not os.path.isfile(f".{rank}.done"), f".{rank}.done already exists!"
+ # model_factory = get_custom_model_factory(model_config, logger)
+ # model, tokenizer = model_factory(train_config, model_config, **kwargs)
+ # parameters = filter(lambda p: p.requires_grad, model.parameters())
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # model.half()
+ # with open(f".{rank}.done", "w"):
+ #     pass
+
+
# Initialize the optimizer and learning rate scheduler
model_engine, _, _, _ = deepspeed.initialize(
model=model, model_parameters=parameters, config=deepspeed_config
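The commented-out block above serializes model construction across ranks: each rank busy-waits for the ".{i}.done" marker files of every lower rank, builds the model, casts it to half precision, and finally writes its own marker so the next rank can start. Loading one rank at a time trades a slower startup for a much lower peak in host memory, which is the point of the <=256 GB hint. A minimal standalone sketch of that sentinel-file pattern; the helper load_model_serially and the build_model callable are illustrative assumptions, not names from the repository:

import os


def load_model_serially(rank: int, build_model):
    """Build the model on one rank at a time, gated by ".{i}.done" sentinel files.

    rank and build_model are illustrative placeholders: rank is this process's
    global rank, and build_model is any callable that constructs and returns the model.
    """
    # Block until every lower rank has finished and written its marker file.
    for i in range(rank):
        while not os.path.isfile(f".{i}.done"):
            pass
    # A stale marker from a previous run would break the ordering, so fail loudly.
    assert not os.path.isfile(f".{rank}.done"), f".{rank}.done already exists!"

    model = build_model()
    model.half()  # halve the host-memory footprint before DeepSpeed takes over

    # Write this rank's marker so the next rank may start building its copy.
    with open(f".{rank}.done", "w"):
        pass
    return model

As the assertion implies, the marker files have to be removed between runs.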
6 changes: 3 additions & 3 deletions src/slam_llm/utils/config_utils.py
@@ -69,7 +69,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs = {}
batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
if train_config.batching_strategy == "padding":
- if train_config.enable_fsdp or train_config.enable_ddp:
+ if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
dataset,
batch_size=batch_size,
@@ -81,7 +81,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
elif train_config.batching_strategy == "packing":
- if train_config.enable_fsdp or train_config.enable_ddp:
+ if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
kwargs["sampler"] = DistributedSampler(
dataset,
rank=dist.get_rank(),
@@ -93,7 +93,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs["collate_fn"] = default_data_collator
else:
# raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
- if train_config.enable_fsdp or train_config.enable_ddp:
+ if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
kwargs["sampler"] = DistributedSampler(
dataset,
rank=dist.get_rank(),
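With enable_deepspeed added to these checks, a DeepSpeed run now takes the same branch as FSDP and DDP: it gets a DistributedLengthBasedBatchSampler (padding) or a DistributedSampler (packing and the fallback), so every rank draws its own shard of the data instead of iterating the full dataset. A hedged sketch of how the returned kwargs would feed a DataLoader; the helper name build_train_dataloader and the num_workers_dataloader config field are assumptions for illustration, while get_dataloader_kwargs and its signature are as shown above:

from torch.utils.data import DataLoader

from slam_llm.utils.config_utils import get_dataloader_kwargs


def build_train_dataloader(train_config, train_dataset, tokenizer):
    # Assumes torch.distributed is already initialized (e.g. by deepspeed.initialize)
    # and train_config.enable_deepspeed is True, so the kwargs carry a distributed
    # (batch) sampler plus the matching collate_fn for the chosen batching_strategy.
    dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, tokenizer, mode="train")
    return DataLoader(
        train_dataset,
        num_workers=train_config.num_workers_dataloader,  # assumed config field
        pin_memory=True,
        **dl_kwargs,
    )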
