[optim] add support for APOLLO (#6617)
zhuhanqing authored Jan 14, 2025
1 parent 9b7ba09 commit d9189f9
Showing 10 changed files with 351 additions and 5 deletions.
45 changes: 45 additions & 0 deletions examples/extras/apollo/llama3_full_sft.yaml
@@ -0,0 +1,45 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: full
use_apollo: true
apollo_layerwise: true
apollo_target: mlp,self_attn
apollo_rank: 128
apollo_scale: 32.0
apollo_scale_type: channel

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/apollo_full-scale32/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
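
Usage note (not part of this diff): with the example file above in place, training can be launched through the standard LLaMA-Factory CLI:

llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
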
1 change: 1 addition & 0 deletions setup.py
@@ -56,6 +56,7 @@ def get_console_scripts() -> List[str]:
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.6.5"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
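
Install note (not part of this diff): with the new extra registered above, the APOLLO dependency can be installed from a source checkout with pip install -e ".[apollo]", or directly with pip install apollo-torch.
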
2 changes: 2 additions & 0 deletions src/llamafactory/extras/packages.py
@@ -49,6 +49,8 @@ def is_fastapi_available():
def is_galore_available():
return _is_package_available("galore_torch")

def is_apollo_available():
return _is_package_available("apollo_torch")

def is_gradio_available():
return _is_package_available("gradio")
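
For reference, a minimal sketch (not part of this diff) of how this helper is meant to be consumed — the same guard-import pattern already used for GaLore, and repeated in trainer_utils.py below:

from llamafactory.extras.packages import is_apollo_available

if is_apollo_available():
    from apollo_torch import APOLLOAdamW  # only imported when apollo-torch is installed
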
64 changes: 62 additions & 2 deletions src/llamafactory/hparams/finetuning_args.py
@@ -251,6 +251,59 @@ class GaloreArguments:
)


@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""

use_apollo: bool = field(
default=False,
metadata={"help": "Whether or not to use the APOLLO optimizer."},
)
apollo_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
apollo_rank: int = field(
default=16,
metadata={"help": "The rank of APOLLO gradients."},
)
apollo_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the APOLLO projection."},
)
apollo_scale: float = field(
default=1.0,
metadata={"help": "APOLLO scaling coefficient."},
)
apollo_proj: Literal["svd", "random"] = field(
default="random",
metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
)
apollo_proj_type: Literal["std", "right", "left"] = field(
default="std",
metadata={"help": "Type of APOLLO projection."},
)
apollo_scale_type: Literal["channel", "tensor"] = field(
default="channel",
metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
)
apollo_layerwise: bool = field(
default=False,
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
)
apollo_scale_front: bool = field(
default=False,
metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
)


@dataclass
class BAdamArgument:
r"""
@@ -334,7 +387,7 @@ class SwanLabArguments:

@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
@@ -401,6 +454,7 @@ def split_arg(arg):
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
@@ -421,12 +475,18 @@ def split_arg(arg):
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")

if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")

if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")

if self.use_galore and self.use_apollo:
raise ValueError("Cannot use GaLore with APOLLO together.")

if self.use_badam and self.use_apollo:
raise ValueError("Cannot use BAdam with APOLLO together.")

if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")

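
An illustrative sketch (not part of this diff; assumes llamafactory is importable and the remaining defaults are left untouched) of how the new fields behave after __post_init__ — apollo_target is split on commas, and APOLLO is mutually exclusive with GaLore, BAdam and LoRA fine-tuning:

from llamafactory.hparams import FinetuningArguments

args = FinetuningArguments(
    finetuning_type="full",
    use_apollo=True,
    apollo_target="mlp,self_attn",
    apollo_rank=128,
    apollo_scale=32.0,
)
print(args.apollo_target)  # ["mlp", "self_attn"] after split_arg

FinetuningArguments(finetuning_type="full", use_apollo=True, use_galore=True)  # raises ValueError
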
18 changes: 18 additions & 0 deletions src/llamafactory/hparams/parser.py
@@ -139,6 +139,9 @@ def _check_extra_dependencies(
if finetuning_args.use_galore:
check_version("galore_torch", mandatory=True)

if finetuning_args.use_apollo:
check_version("apollo_torch", mandatory=True)

if finetuning_args.use_badam:
check_version("badam>=1.2.1", mandatory=True)

@@ -262,6 +265,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
):
raise ValueError("Distributed training does not support layer-wise GaLore.")

if (
finetuning_args.use_apollo
and finetuning_args.apollo_layerwise
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise APOLLO.")

if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
@@ -271,6 +281,9 @@
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed yet.")

if finetuning_args.use_apollo and training_args.deepspeed is not None:
raise ValueError("APOLLO is incompatible with DeepSpeed yet.")

if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")

@@ -306,6 +319,11 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
)

if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
logger.warning_rank0(
"Using APOLLO with mixed precision training may significantly increases GPU memory usage."
)

if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

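
Taken together with the checks in finetuning_args.py, the layer-wise variant (apollo_layerwise: true, as in the example YAML above) is effectively limited to single-process training: it rejects ParallelMode.DISTRIBUTED here, APOLLO as a whole rejects DeepSpeed, and _create_apollo_optimizer below additionally requires gradient_accumulation_steps == 1. Setting pure_bf16: true, as the example does, also avoids the mixed-precision memory warning added in this hunk.
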
2 changes: 1 addition & 1 deletion src/llamafactory/model/model_utils/misc.py
@@ -27,7 +27,7 @@

def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply lora or galore.
Finds all available modules to apply lora, galore or apollo.
"""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
93 changes: 91 additions & 2 deletions src/llamafactory/train/trainer_utils.py
@@ -32,14 +32,16 @@

from ..extras import logging
from ..extras.constants import IGNORE_INDEX
from ..extras.packages import is_galore_available, is_ray_available
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params


if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore

if is_apollo_available():
from apollo_torch import APOLLOAdamW # type: ignore

if is_ray_available():
from ray.train import RunConfig, ScalingConfig
@@ -58,7 +60,7 @@

class DummyOptimizer(torch.optim.Optimizer):
r"""
A dummy optimizer used for the GaLore algorithm.
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""

def __init__(
@@ -275,6 +277,90 @@ def optimizer_hook(param: "torch.nn.Parameter"):
logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
return optimizer

def _create_apollo_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
apollo_targets = finetuning_args.apollo_target

apollo_params: List["torch.nn.Parameter"] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
for param in module.parameters():
if param.requires_grad and len(param.shape) > 1:
apollo_params.append(param)

apollo_kwargs = {
"rank": finetuning_args.apollo_rank,
"proj": finetuning_args.apollo_proj,
"proj_type": finetuning_args.apollo_proj_type,
"update_proj_gap": finetuning_args.apollo_update_interval,
"scale": finetuning_args.apollo_scale,
"scale_type": finetuning_args.apollo_scale_type,
"scale_front": finetuning_args.apollo_scale_front,
}

id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], []  # they are non-APOLLO parameters
trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
trainable_params.append(param)
if id(param) not in id_apollo_params:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)

_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

if training_args.optim == "adamw_torch":
optim_class = APOLLOAdamW
else:
raise NotImplementedError(f"Unknown optim: {training_args.optim}")

if finetuning_args.apollo_layerwise:
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")

optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in apollo_params:  # apollo params have weight decay
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()

for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)

optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
]
optimizer = optim_class(param_groups, **optim_kwargs)

logger.info_rank0("Using APOLLO optimizer.")
return optimizer


def _create_loraplus_optimizer(
model: "PreTrainedModel",
@@ -410,6 +496,9 @@ def create_custom_optimizer(
if finetuning_args.use_galore:
return _create_galore_optimizer(model, training_args, finetuning_args)

if finetuning_args.use_apollo:
return _create_apollo_optimizer(model, training_args, finetuning_args)

if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, training_args, finetuning_args)

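
For orientation, a minimal standalone sketch (not part of this diff) of the non-layerwise branch above. Assumptions: apollo-torch is installed and APOLLOAdamW accepts the same per-group keys that _create_apollo_optimizer passes (rank, proj, proj_type, update_proj_gap, scale, scale_type, scale_front); the toy model and hyperparameters are illustrative only.

import torch
from apollo_torch import APOLLOAdamW

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
apollo_kwargs = {
    "rank": 16, "proj": "random", "proj_type": "std", "update_proj_gap": 200,
    "scale": 1.0, "scale_type": "channel", "scale_front": False,
}
# 2D weight matrices get the low-rank APOLLO treatment; 1D params (biases) go in a plain group.
param_groups = [
    dict(params=[p for p in model.parameters() if p.dim() > 1], weight_decay=0.01, **apollo_kwargs),
    dict(params=[p for p in model.parameters() if p.dim() == 1], weight_decay=0.0),
]
optimizer = APOLLOAdamW(param_groups, lr=1e-5)

loss = model(torch.randn(4, 64)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
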
19 changes: 19 additions & 0 deletions src/llamafactory/webui/components/train.py
@@ -250,6 +250,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)

with gr.Accordion(open=False) as apollo_tab:
with gr.Row():
use_apollo = gr.Checkbox()
apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
apollo_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
apollo_target = gr.Textbox(value="all")
input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
elem_dict.update(
dict(
apollo_tab=apollo_tab,
use_apollo=use_apollo,
apollo_rank=apollo_rank,
apollo_update_interval=apollo_update_interval,
apollo_scale=apollo_scale,
apollo_target=apollo_target,
)
)

with gr.Accordion(open=False) as badam_tab:
with gr.Row():
use_badam = gr.Checkbox()