Implement activation offloading and opt_in_bwd in knowledge_distillation recipes #2088

Status: Open · wants to merge 2 commits into base: main
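This PR wires activation offloading into the single-device knowledge distillation recipe behind a new `enable_activation_offloading` config flag. The underlying mechanism is torchtune's `OffloadActivations` saved-tensors context manager, which moves activations to CPU during the forward pass and brings them back as the backward pass needs them. A minimal standalone sketch of that mechanism, separate from the recipe (CUDA required; the toy model and the explicit `use_streams=True` are illustrative assumptions, not part of this diff):

import torch
import torch.nn as nn
from torchtune.training import OffloadActivations

# Small stand-in model; the recipe applies the same idea to the student TransformerDecoder.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()
x = torch.randn(8, 512, device="cuda", requires_grad=True)

with OffloadActivations(use_streams=True):
    out = model(x)        # activations saved for backward are offloaded to CPU
out.sum().backward()      # offloaded activations are fetched back on demand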
6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml
7 changes: 7 additions & 0 deletions .idea/misc.xml
10 changes: 10 additions & 0 deletions .idea/torchtune.iml
6 changes: 6 additions & 0 deletions .idea/vcs.xml
57 changes: 57 additions & 0 deletions .idea/workspace.xml

86 changes: 62 additions & 24 deletions recipes/knowledge_distillation_single_device.py
@@ -31,7 +31,7 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY

from tqdm import tqdm

@@ -120,6 +120,11 @@ def __init__(self, cfg: DictConfig) -> None:
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
self._enable_activation_offloading = cfg.get("enable_activation_offloading", False)
# train() below references `self.use_wandb`; assume it is driven by an optional
# config flag so the attribute always exists.
self.use_wandb = cfg.get("use_wandb", False)

if self._enable_activation_offloading and self._device.type != "cuda":
raise RuntimeError("Activation offloading requires a CUDA device.")

if self._log_peak_memory_stats and self._device.type != "cuda":
log.info(
@@ -235,6 +240,7 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
compile_model=cfg.compile,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=(
@@ -244,9 +250,12 @@
),
)



# Activation offloading is applied to the student only; _setup_teacher_model
# does not take an offloading flag.
self._teacher_model = self._setup_teacher_model(
model_cfg=cfg.teacher_model,
model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY],
)

self._tokenizer = config.instantiate(cfg.tokenizer)
@@ -391,6 +400,7 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
compile_model: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -443,6 +453,26 @@ def _setup_model(
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
# Set up the activation offloading context for the student model. The torchtune
# helper returns an OffloadActivations context when offloading is enabled and a
# no-op context otherwise. It also wraps the final output Linear in a NoOpManager,
# since offloading that activation and bringing it straight back is expensive and
# interferes with the chunked cross-entropy loss.
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)

log.info(f"Student model is initialized with precision {self._dtype}.")

@@ -457,7 +487,7 @@ def _setup_model(
def _setup_teacher_model(
self,
model_cfg: DictConfig,
model_state_dict: Dict[str, Any],
) -> nn.Module:
with training.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)
@@ -662,7 +692,7 @@ def _loss_step(

def train(self) -> None:
"""
The core training loop, with optional Weights & Biases logging.
"""

if self._compile:
@@ -676,6 +706,10 @@ def train(self) -> None:
running_kd_loss = 0
num_tokens = 0

if self.use_wandb:
import wandb
# Project name is a placeholder; `config=` is omitted because the recipe
# does not keep the full DictConfig on `self`.
wandb.init(project="your_project_name")

with self._profiler as prof:
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -686,17 +720,17 @@ def train(self) -> None:
pbar = tqdm(total=self._steps_per_epoch)
for idx, batch in enumerate(self._dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps)
== self.max_steps_per_epoch
):
break

# Start tracking CUDA memory for active steps for just the first epoch
if (
curr_epoch == 0
and self.profiler_profile_memory
and idx == self.profiler_wait_steps + self.profiler_warmup_steps
):
torch.cuda.memory._record_memory_history()

@@ -705,16 +739,16 @@ def train(self) -> None:
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

class_loss, kd_loss = self._loss_step(batch)
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens
current_loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
current_loss.backward()

# Step with optimizer
@@ -734,8 +768,8 @@ def train(self) -> None:
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
loss_to_log = (
1 - self._kd_ratio
) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -752,14 +786,17 @@ def train(self) -> None:
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
if (
self._device.type == "cuda"
and self._log_peak_memory_stats
):
log_dict.update(
training.get_memory_stats(device=self._device)
)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
if self.use_wandb:
wandb.log(log_dict)

self._metric_logger.log_dict(
log_dict,
step=self.global_step,
Expand All @@ -773,23 +810,24 @@ def train(self) -> None:

# Stop tracking CUDA memory now that active steps are complete
if (
curr_epoch == 0
and self.profiler_profile_memory
and idx
== self.profiler_wait_steps
+ self.profiler_warmup_steps
+ self.profiler_active_steps
):
torch.cuda.memory._record_memory_history(enabled=None)

# Step the profiler
# Note we are stepping each batch, which might not include optimizer step in the trace
# if the schedule cycle doesn't align with gradient accumulation.
prof.step()

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)

if self.use_wandb:
wandb.finish()

def cleanup(self) -> None:
self._metric_logger.close()

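Not shown in these hunks: offloading only takes effect if the student forward pass actually runs inside `self.activations_handling_ctx`. A minimal sketch of the corresponding change in `_loss_step`, following the pattern in torchtune's other single-device recipes (the exact call signature of `self._model` here is an assumption, not taken from this diff):

# Inside KDRecipeSingleDevice._loss_step (sketch only)
with self.activations_handling_ctx:
    logits = self._model(**batch)  # student forward; activations saved for backward are offloaded
# ... teacher forward and the class/KD loss computation continue unchanged ...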
2 changes: 1 addition & 1 deletion tests/recipes/test_knowledge_distillation_single_device.py
@@ -35,7 +35,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
"device=cpu",
f"dtype={dtype_str}",
"enable_activation_checkpointing=False",
"enable_activation_offloading=False",
"enable_activation_offloading=True",
"dataset.train_on_input=False",
"seed=9",
f"epochs={epochs}",
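The PR title also mentions `opt_in_bwd`, but none of the hunks above touch the optimizer path. For reference, a sketch of the pattern torchtune's full-finetune recipes use for `optimizer_in_bwd` (fusing the optimizer step into the backward pass), which this recipe would presumably mirror in its optimizer setup; the helper names come from `torchtune.training`, while the surrounding wiring here is an assumption:

# Sketch: one optimizer instance per parameter, stepped from gradient hooks.
optim_dict = {
    param: config.instantiate(cfg_optimizer, [param])
    for param in self._model.parameters()
}
training.register_optim_in_bwd_hooks(model=self._model, optim_dict=optim_dict)
self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
    model=self._model, optim_dict=optim_dict
)
# With this enabled, train() skips the explicit optimizer.step() / zero_grad()
# calls, since each parameter is updated as soon as its gradient is produced.

A LoRA-based recipe like this one would presumably restrict `optim_dict` to the trainable adapter parameters rather than all model parameters.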