Merge pull request #153 from nuaalixu/main
fix #92 for fsdp training
ddlBoJack authored Oct 15, 2024
2 parents 38d8c66 + 832bf02 commit 0045773
Showing 9 changed files with 39 additions and 10 deletions.
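
The change is the same in every example config: FSDPConfig.sharding_strategy is now annotated as torch.distributed.fsdp.ShardingStrategy instead of str, so the string default "NO_SHARD" can be resolved to the enum by the structured-config machinery rather than by a manual getattr call in the training pipeline. A minimal sketch of the idea, not part of this commit, assuming the configs are consumed through OmegaConf/Hydra structured configs as in the rest of the repo:

# Illustration only: with an enum-typed field, OmegaConf converts the string
# "NO_SHARD" (or an override such as "FULL_SHARD") into the enum member.
from dataclasses import dataclass

from omegaconf import OmegaConf
from torch.distributed.fsdp import ShardingStrategy


@dataclass
class FSDPConfig:
    # String default with an enum annotation; OmegaConf validates and converts it.
    sharding_strategy: ShardingStrategy = "NO_SHARD"


cfg = OmegaConf.structured(FSDPConfig)
assert cfg.sharding_strategy is ShardingStrategy.NO_SHARD

# String overrides from YAML or the command line are converted the same way.
cfg = OmegaConf.merge(cfg, {"sharding_strategy": "FULL_SHARD"})
assert cfg.sharding_strategy is ShardingStrategy.FULL_SHARD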
6 changes: 5 additions & 1 deletion examples/aac_audiocaps/aac_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/aac_audiocaps/model/slam_model_aac.py:model_factory"
@@ -114,7 +118,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
6 changes: 5 additions & 1 deletion examples/asr_librispeech/asr_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/asr_librispeech/model/slam_model_asr.py:model_factory"
@@ -108,7 +112,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
6 changes: 5 additions & 1 deletion examples/drcap_zeroshot_aac/drcap_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/drcap_zeroshot_aac/model/slam_model_drcap.py:model_factory"
@@ -113,7 +117,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
5 changes: 4 additions & 1 deletion examples/mala_asr_slidespeech/mala_asr_config.py
@@ -1,5 +1,8 @@
from dataclasses import dataclass, field
from typing import Optional, List
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/mala_asr_slidespeech/model/slam_model_mala_asr.py:model_factory"
@@ -109,7 +112,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
5 changes: 4 additions & 1 deletion examples/mc_musiccaps/mir_config.py
@@ -1,5 +1,8 @@
from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
@dataclass
class ModelConfig:
    file: str = "examples/mc_musiccaps/model/slam_model_mir.py:model_factory"
@@ -112,7 +115,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
5 changes: 4 additions & 1 deletion examples/seld_spatialsoundqa/seld_config.py
@@ -1,6 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List

+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/seld_spatialsoundqa/model/slam_model_seld.py:model_factory"
@@ -97,7 +100,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
6 changes: 5 additions & 1 deletion examples/vallex/vallex_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    llm_name: str = "vallex"
@@ -68,7 +72,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
6 changes: 5 additions & 1 deletion examples/vsr_LRS3/vsr_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
@dataclass
class ModelConfig:
    file: str = "examples/vsr_LRS3/model/slam_model_vsr.py:model_factory"
@@ -115,7 +119,7 @@ class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
4 changes: 2 additions & 2 deletions src/slam_llm/pipeline/finetune.py
@@ -159,8 +159,8 @@ def main(kwargs: DictConfig):
if not train_config.use_peft and train_config.freeze_layers:

    freeze_transformer_layers(train_config.num_freeze_layers)
-from torch.distributed.fsdp import ShardingStrategy
-fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
+# from torch.distributed.fsdp import ShardingStrategy
+# fsdp_config.sharding_strategy = getattr(ShardingStrategy, fsdp_config.sharding_strategy)
mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
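
With the field already typed as the enum, the manual string-to-enum conversion in finetune.py is no longer needed, so it is commented out above. A hypothetical illustration, not from the repo, of why the old conversion cannot simply be kept: getattr requires a string attribute name, so applying it to a value that is already a ShardingStrategy would fail.

# Illustration only: the typed config now yields the enum directly, and that
# value can be passed straight to FSDP(sharding_strategy=...).
from torch.distributed.fsdp import ShardingStrategy

strategy = ShardingStrategy.NO_SHARD      # what the typed config now provides
try:
    getattr(ShardingStrategy, strategy)   # the old conversion path
except TypeError as err:
    print(err)                            # e.g. "attribute name must be string, ..."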
