format codes

fclearner committed Aug 4, 2024
1 parent 4c1f2da commit 4d89439
Showing 2 changed files with 19 additions and 14 deletions.
6 changes: 3 additions & 3 deletions wenet/finetune/lora/utils.py
@@ -32,9 +32,9 @@ def inject_lora(module, lora_config):
         if hasattr(module, lora_attr):
             submodule = getattr(module, lora_attr)
             n_feat = submodule.in_features
-            lora_linear = lora.Linear(n_feat, n_feat, r=lora_rank,
-                    lora_alpha=lora_alpha,
-                    lora_dropout=lora_dropout)
+            lora_linear = lora.Linear(n_feat, n_feat, r=lora_rank,
+                                      lora_alpha=lora_alpha,
+                                      lora_dropout=lora_dropout)
             setattr(module, lora_attr, lora_linear)
 
 
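For context, the hunk above touches the code path that swaps a plain linear projection for loralib's lora.Linear. The standalone sketch below mirrors that hasattr/getattr/setattr pattern on a toy module; TinyAttention, the attribute names, and the hyperparameter values are made up for illustration, so this is a minimal sketch assuming loralib is installed, not WeNet's actual implementation.

import torch.nn as nn
import loralib as lora


class TinyAttention(nn.Module):
    """Toy stand-in for an attention block (illustrative only)."""

    def __init__(self, n_feat: int = 256):
        super().__init__()
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)


def inject_lora_sketch(module, lora_list, lora_rank=8, lora_alpha=8,
                       lora_dropout=0.0):
    # Replace each named projection with a LoRA-augmented linear layer,
    # following the same attribute-swapping idiom as the diff above.
    for lora_attr in lora_list:
        if hasattr(module, lora_attr):
            submodule = getattr(module, lora_attr)
            n_feat = submodule.in_features
            lora_linear = lora.Linear(n_feat, n_feat, r=lora_rank,
                                      lora_alpha=lora_alpha,
                                      lora_dropout=lora_dropout)
            setattr(module, lora_attr, lora_linear)


attn = TinyAttention()
inject_lora_sketch(attn, ["linear_q", "linear_out"])
# loralib helper: freeze everything except the LoRA parameters,
# which is the effect the only_optimize_lora flag aims for.
lora.mark_only_lora_as_trainable(attn)
print(type(attn.linear_q))  # <class 'loralib.layers.Linear'>

As a side note on the docstring edited in the second file: loralib's Linear merges the low-rank update into the base weight when the module is switched to eval mode (merge_weights defaults to True), which is what the remark about model.eval() / model.train(mode=False) refers to.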
27 changes: 16 additions & 11 deletions wenet/utils/train_utils.py
@@ -116,9 +116,12 @@ def add_dataset_args(parser):
 
 
 def add_lora_args(parser):
-    '''Configure parameters for LoRA fine-tuning. Set use_lora and only_optimize_lora to true to enable LoRA functionality.
-    LoRA will be injected to model through lora_modules, lora_attn_attr, lora_list.
-    LoRA weights will be merged after calling model.eval() (or model.train(mode=False)).
+    '''Configure parameters for LoRA fine-tuning. Set use_lora and
+    only_optimize_lora to true to enable LoRA functionality.
+    LoRA will be injected to model through (lora_modules, lora_attn_attr,
+    lora_list).
+    LoRA weights will be merged after calling model.eval()
+    (or model.train(mode=False)).
     LoRA weights need to be loaded after fine-tuning with DeepSpeed.
     '''
     parser.add_argument("--use_lora",
@@ -136,14 +139,16 @@ def add_lora_args(parser):
         type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
         help='modules names needs inject lora',
         )
-    parser.add_argument("--lora_attn_attr",
-                        default="self_attn,src_attn",
-                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
-                        help="lora_attn_attr.")
-    parser.add_argument("--lora_list",
-                        default="linear_out,linear_q,linear_k,linear_v",
-                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
-                        help="lora module list.")
+    parser.add_argument(
+        "--lora_attn_attr",
+        default="self_attn,src_attn",
+        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+        help="lora_attn_attr.")
+    parser.add_argument(
+        "--lora_list",
+        default="linear_out,linear_q,linear_k,linear_v",
+        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
+        help="lora module list.")
     parser.add_argument("--lora_rank",
                         default=8,
                         type=int,
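The list-valued flags above share one converter: a lambda that splits a comma-separated string into a list of module names and yields an empty list only when the whole string is empty. Below is a minimal, self-contained sketch of that parsing behavior, reusing the flag names and defaults from the diff; the parser here is a throwaway example, not WeNet's full training CLI.

import argparse


def csv_list(s):
    # Same behavior as the inline lambda in add_lora_args:
    # split on commas; if the whole string is empty, return an empty list.
    return [str(mod) for mod in s.split(",") if s != ""]


parser = argparse.ArgumentParser(description="LoRA flag parsing sketch")
parser.add_argument("--lora_attn_attr", default="self_attn,src_attn",
                    type=csv_list)
parser.add_argument("--lora_list",
                    default="linear_out,linear_q,linear_k,linear_v",
                    type=csv_list)

args = parser.parse_args(["--lora_attn_attr", "self_attn",
                          "--lora_list", "linear_q,linear_v"])
print(args.lora_attn_attr)  # ['self_attn']
print(args.lora_list)       # ['linear_q', 'linear_v']

Because the defaults are plain strings, argparse also runs them through the converter when a flag is omitted, so downstream code always receives a list.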
