From 832bf02e65f07e324294dd75c5ab03e135dc5c54 Mon Sep 17 00:00:00 2001
From: "xu.li"
Date: Mon, 14 Oct 2024 11:41:38 +0800
Subject: [PATCH] fix #92 for fsdp training

---
 examples/aac_audiocaps/aac_config.py             | 6 +++++-
 examples/asr_librispeech/asr_config.py           | 6 +++++-
 examples/drcap_zeroshot_aac/drcap_config.py      | 6 +++++-
 examples/mala_asr_slidespeech/mala_asr_config.py | 5 ++++-
 examples/mc_musiccaps/mir_config.py              | 5 ++++-
 examples/seld_spatialsoundqa/seld_config.py      | 5 ++++-
 examples/vallex/vallex_config.py                 | 6 +++++-
 examples/vsr_LRS3/vsr_config.py                  | 6 +++++-
 src/slam_llm/pipeline/finetune.py                | 4 ++--
 9 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/examples/aac_audiocaps/aac_config.py b/examples/aac_audiocaps/aac_config.py
index 397abc13..0798c87d 100644
--- a/examples/aac_audiocaps/aac_config.py
+++ b/examples/aac_audiocaps/aac_config.py
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/aac_audiocaps/model/slam_model_aac.py:model_factory"
@@ -114,7 +118,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/examples/asr_librispeech/asr_config.py b/examples/asr_librispeech/asr_config.py
index 280280dc..d6683157 100644
--- a/examples/asr_librispeech/asr_config.py
+++ b/examples/asr_librispeech/asr_config.py
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/asr_librispeech/model/slam_model_asr.py:model_factory"
@@ -108,7 +112,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/examples/drcap_zeroshot_aac/drcap_config.py b/examples/drcap_zeroshot_aac/drcap_config.py
index b4f4956f..328f0c23 100644
--- a/examples/drcap_zeroshot_aac/drcap_config.py
+++ b/examples/drcap_zeroshot_aac/drcap_config.py
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/drcap_zeroshot_aac/model/slam_model_drcap.py:model_factory"
@@ -113,7 +117,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/examples/mala_asr_slidespeech/mala_asr_config.py b/examples/mala_asr_slidespeech/mala_asr_config.py
index 7c8e2794..ba98ec78 100644
--- a/examples/mala_asr_slidespeech/mala_asr_config.py
+++ b/examples/mala_asr_slidespeech/mala_asr_config.py
@@ -1,5 +1,8 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/mala_asr_slidespeech/model/slam_model_mala_asr.py:model_factory"
@@ -109,7 +112,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/examples/mc_musiccaps/mir_config.py b/examples/mc_musiccaps/mir_config.py
index a1279b12..e73fd279 100644
--- a/examples/mc_musiccaps/mir_config.py
+++ b/examples/mc_musiccaps/mir_config.py
@@ -1,5 +1,8 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
 @dataclass
 class ModelConfig:
     file: str = "examples/mc_musiccaps/model/slam_model_mir.py:model_factory"
@@ -112,7 +115,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/examples/seld_spatialsoundqa/seld_config.py b/examples/seld_spatialsoundqa/seld_config.py
index 439e5819..ce69f27f 100644
--- a/examples/seld_spatialsoundqa/seld_config.py
+++ b/examples/seld_spatialsoundqa/seld_config.py
@@ -1,6 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
 
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/seld_spatialsoundqa/model/slam_model_seld.py:model_factory"
@@ -97,7 +100,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/examples/vallex/vallex_config.py b/examples/vallex/vallex_config.py
index a4cd350a..e03201ec 100644
--- a/examples/vallex/vallex_config.py
+++ b/examples/vallex/vallex_config.py
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     llm_name: str = "vallex"
@@ -68,7 +72,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/examples/vsr_LRS3/vsr_config.py b/examples/vsr_LRS3/vsr_config.py
index 3c051a01..0a271eef 100644
--- a/examples/vsr_LRS3/vsr_config.py
+++ b/examples/vsr_LRS3/vsr_config.py
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/vsr_LRS3/model/slam_model_vsr.py:model_factory"
@@ -115,7 +119,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
diff --git a/src/slam_llm/pipeline/finetune.py b/src/slam_llm/pipeline/finetune.py
index f9132263..4ced3c51 100644
--- a/src/slam_llm/pipeline/finetune.py
+++ b/src/slam_llm/pipeline/finetune.py
@@ -159,8 +159,8 @@ def main(kwargs: DictConfig):
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(train_config.num_freeze_layers)
 
-        from torch.distributed.fsdp import ShardingStrategy
-        fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
+        # from torch.distributed.fsdp import ShardingStrategy
+        # fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
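
Note (not part of the patch): the annotation change appears to be sufficient because the example configs are consumed as OmegaConf/Hydra structured configs (finetune.py receives a DictConfig), and OmegaConf validates each field against its declared type, coercing a string into an Enum member of the same name. That is why the manual getattr() conversion in finetune.py can be commented out. The snippet below is a minimal sketch of that coercion for illustration only; it is not code from the repository and assumes omegaconf and torch are installed.

from dataclasses import dataclass

from omegaconf import OmegaConf
from torch.distributed.fsdp import ShardingStrategy


@dataclass
class FSDPConfig:
    # With the field annotated as ShardingStrategy, OmegaConf converts the
    # string default "NO_SHARD" into ShardingStrategy.NO_SHARD when the
    # structured config is built, so no manual getattr() lookup is needed.
    sharding_strategy: ShardingStrategy = "NO_SHARD"


cfg = OmegaConf.structured(FSDPConfig)
assert cfg.sharding_strategy == ShardingStrategy.NO_SHARD
print(cfg.sharding_strategy)  # ShardingStrategy.NO_SHARD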