diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py
index ef238da44d..8ced90014d 100644
--- a/recipes/knowledge_distillation_single_device.py
+++ b/recipes/knowledge_distillation_single_device.py
@@ -120,6 +120,11 @@ def __init__(self, cfg: DictConfig) -> None:
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading and self._device.type != "cuda":
+            raise RuntimeError("Activation offloading requires a CUDA device.")
+
+        # Optional direct wandb logging used in train(); both attributes are
+        # referenced there, so define them up front.
+        self.use_wandb = cfg.get("use_wandb", False)
+        self._cfg = cfg
if self._log_peak_memory_stats and self._device.type != "cuda":
log.info(
@@ -235,6 +240,7 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+ enable_activation_offloading=self._enable_activation_offloading,
compile_model=cfg.compile,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=(
@@ -391,6 +400,7 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
+ enable_activation_offloading: bool,
compile_model: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -443,6 +453,26 @@ def _setup_model(
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
+        # Context manager for the student forward pass that offloads saved
+        # activations to CPU when enable_activation_offloading is True; when it
+        # is False this is a no-op context, so the call site stays the same.
+        # get_act_offloading_ctx_manager also avoids offloading the final output
+        # Linear, since moving that activation to CPU and bringing it straight
+        # back costs more than it saves and interferes with chunked cross-entropy.
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
log.info(f"Student model is initialized with precision {self._dtype}.")
@@ -662,7 +692,7 @@ def _loss_step(
def train(self) -> None:
"""
- The core training loop.
+        The core training loop, with optional Weights & Biases logging.
"""
if self._compile:
@@ -676,6 +706,10 @@ def train(self) -> None:
running_kd_loss = 0
num_tokens = 0
+        if self.use_wandb:
+            # Imported lazily so wandb stays an optional dependency.
+            import wandb
+
+            # NOTE: "your_project_name" is a placeholder; set your own project.
+            wandb.init(project="your_project_name", config=self._cfg)
+
with self._profiler as prof:
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -752,14 +786,17 @@ def train(self) -> None:
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
if (
                            self._device.type == "cuda"
                            and self._log_peak_memory_stats
):
log_dict.update(
training.get_memory_stats(device=self._device)
)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
+ if self.use_wandb:
+ wandb.log(log_dict)
+
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
@@ -773,23 +810,24 @@ def train(self) -> None:
# Stop tracking CUDA memory now that active steps are complete
if (
                        curr_epoch == 0
                        and self.profiler_profile_memory
                        and idx
                        == self.profiler_wait_steps
                        + self.profiler_warmup_steps
                        + self.profiler_active_steps
):
torch.cuda.memory._record_memory_history(enabled=None)
# Step the profiler
                    # Note we are stepping each batch, which might not include optimizer step in the trace
                    # if the schedule cycle doesn't align with gradient accumulation.
prof.step()
self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)
+ if self.use_wandb:
+ wandb.finish()
+
def cleanup(self) -> None:
self._metric_logger.close()
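
Note: the hunks above construct self.activations_handling_ctx but do not show a call site. OffloadActivations only affects tensors that are saved for backward while the context manager is active, so for offloading to have any effect the student forward pass in _loss_step has to run inside it. A minimal sketch of that change, assuming the recipe computes the student logits via self._model(**batch) (the body of _loss_step is not part of this diff):

    # Inside _loss_step (sketch; the surrounding code is not shown in this diff).
    # Only the student forward goes under the offloading context: the teacher
    # forward does not need gradients, so it saves no activations to offload.
    with self.activations_handling_ctx:
        logits = self._model(**batch)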