diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py
index ef238da44d..8ced90014d 100644
--- a/recipes/knowledge_distillation_single_device.py
+++ b/recipes/knowledge_distillation_single_device.py
@@ -120,6 +120,15 @@ def __init__(self, cfg: DictConfig) -> None:
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading and self._device.type != "cuda":
+            raise RuntimeError("Activation offloading requires a CUDA device.")
+
+        # Optional direct Weights & Biases logging, used in train().
+        self.use_wandb = cfg.get("use_wandb", False)
+        self._cfg = cfg
 
         if self._log_peak_memory_stats and self._device.type != "cuda":
             log.info(
@@ -235,6 +244,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
             compile_model=cfg.compile,
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
@@ -391,6 +401,7 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
         compile_model: bool,
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -443,6 +454,12 @@ def _setup_model(
             training.validate_expected_param_dtype(
                 self.adapter_params.items(), dtype=self._dtype
             )
+
+        # Set up activation offloading for the student model. This returns a
+        # no-op context manager when offloading is disabled.
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
 
         log.info(f"Student model is initialized with precision {self._dtype}.")
@@ -662,7 +679,7 @@ def _loss_step(
 
     def train(self) -> None:
         """
-        The core training loop.
+        The core training loop, with optional Weights & Biases logging.
         """
 
         if self._compile:
@@ -676,6 +693,12 @@ def train(self) -> None:
         running_kd_loss = 0
         num_tokens = 0
 
+        # Optional W&B run, gated on the `use_wandb` config flag.
+        if self.use_wandb:
+            import wandb
+
+            wandb.init(project="your_project_name", config=dict(self._cfg))
+
         with self._profiler as prof:
             # self.epochs_run should be non-zero when we're resuming from a checkpoint
             for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -752,14 +775,17 @@ def train(self) -> None:
                                 "tokens_per_second_per_gpu": num_tokens / time_per_step,
                             }
                             if (
                                 self._device.type == "cuda"
                                 and self._log_peak_memory_stats
                             ):
                                 log_dict.update(
                                     training.get_memory_stats(device=self._device)
                                 )
                             if self._clip_grad_norm is not None:
                                 log_dict.update({"grad_norm": grad_norm})
+                            if self.use_wandb:
+                                wandb.log(log_dict, step=self.global_step)
+
                             self._metric_logger.log_dict(
                                 log_dict,
                                 step=self.global_step,
@@ -773,23 +799,26 @@ def train(self) -> None:
                     # Stop tracking CUDA memory now that active steps are complete
                     if (
                         curr_epoch == 0
                         and self.profiler_profile_memory
                         and idx
                         == self.profiler_wait_steps
                         + self.profiler_warmup_steps
                         + self.profiler_active_steps
                     ):
                         torch.cuda.memory._record_memory_history(enabled=None)
 
                     # Step the profiler
                     # Note we are stepping each batch, which might not include optimizer step in the trace
                     # if the schedule cycle doesn't align with gradient accumulation.
                     prof.step()
 
                 self.epochs_run += 1
                 self.save_checkpoint(epoch=curr_epoch)
 
+        if self.use_wandb:
+            wandb.finish()
+
     def cleanup(self) -> None:
         self._metric_logger.close()
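Note that the hunks above construct `self.activations_handling_ctx` but never apply it, so the flag alone would not change memory behavior; the context manager has to wrap the student's forward pass in `_loss_step`. A minimal sketch of that missing piece, assuming the recipe's existing `tokens`/`mask`/`input_pos` call signature (variable names follow the surrounding recipe and are illustrative):

```python
# Sketch: inside KDRecipeSingleDevice._loss_step, not part of the diff above.
# The context manager is a nullcontext when offloading is disabled, so the
# wrapper is safe to apply unconditionally.
with self.activations_handling_ctx:
    logits = self._model(tokens, mask=mask, input_pos=input_pos)

# The teacher only runs inference; under no_grad no activations are saved
# for backward, so there is nothing to offload on the teacher side.
with torch.no_grad():
    teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos)
```

This is also why the teacher setup path takes no offloading flag: offloading only pays off for activations that are kept alive for the backward pass.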
diff --git a/tests/recipes/test_knowledge_distillation_single_device.py b/tests/recipes/test_knowledge_distillation_single_device.py
index 713e05c98f..1424cdc45e 100644
--- a/tests/recipes/test_knowledge_distillation_single_device.py
+++ b/tests/recipes/test_knowledge_distillation_single_device.py
@@ -35,7 +35,8 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
             "device=cpu",
             f"dtype={dtype_str}",
             "enable_activation_checkpointing=False",
-            "enable_activation_offloading=False",
+            # Offloading requires CUDA; these tests run on CPU, so keep it off.
+            "enable_activation_offloading=False",
             "dataset.train_on_input=False",
             "seed=9",
             f"epochs={epochs}",
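For context on what the recipe relies on from `training.get_act_offloading_ctx_manager`: when offloading is enabled it returns an `OffloadActivations` context and exempts the model's final output projection via a `NoOpManager`, since offloading an activation that is needed again almost immediately costs more than it saves and interferes with chunked cross-entropy. A rough sketch of that behavior, assuming torchtune's `OffloadActivations` and `NoOpManager`; the helper name is illustrative, not the library's API:

```python
import contextlib

from torch import nn
from torchtune.training import NoOpManager, OffloadActivations


def make_offloading_ctx(model: nn.Module, enabled: bool):
    """Sketch of the behavior behind training.get_act_offloading_ctx_manager."""
    if not enabled:
        # Hand back a do-nothing context manager so call sites can wrap the
        # forward pass unconditionally.
        return contextlib.nullcontext()

    ctx = OffloadActivations()

    # Exempt the final output linear: its activation is consumed again almost
    # immediately by the loss, so round-tripping it through CPU memory costs
    # more than it saves.
    if hasattr(model, "output") and isinstance(model.output, nn.Module):
        noop_ctx = NoOpManager()
        model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
        model.output.register_forward_hook(
            lambda *args: noop_ctx.__exit__(), always_call=True
        )
    return ctx
```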