From 288427e0eecf2060673195042efbd0d00f800d66 Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Fri, 5 Jul 2024 16:25:09 -0400
Subject: [PATCH] m

---
 kronfluence/computer/factor_computer.py | 10 ++++-
 kronfluence/factor/covariance.py        |  3 +-
 kronfluence/factor/eigen.py             |  3 +-
 kronfluence/module/tracked_module.py    |  7 ++-
 kronfluence/module/tracker/base.py      |  6 +++
 kronfluence/module/tracker/factor.py    | 57 ++++++++++++++-----------
 6 files changed, 54 insertions(+), 32 deletions(-)

diff --git a/kronfluence/computer/factor_computer.py b/kronfluence/computer/factor_computer.py
index 32145b9..2fb12e2 100644
--- a/kronfluence/computer/factor_computer.py
+++ b/kronfluence/computer/factor_computer.py
@@ -297,6 +297,7 @@ def fit_covariance_matrices(
                     total_data_examples=max_partition_examples,
                 )
 
+            self._reset_memory()
             start_time = get_time(state=self.state)
             with self.profiler.profile("Fit Covariance"):
                 loader = self._get_dataloader(
@@ -331,8 +332,9 @@ def fit_covariance_matrices(
                     metadata=factor_args.to_str_dict(),
                 )
             self.state.wait_for_everyone()
-            del covariance_factors, loader
             self.logger.info(f"Saved covariance matrices at `{factors_output_dir}`.")
+            del num_data_processed, covariance_factors, loader
+            self._reset_memory()
 
         all_end_time = get_time(state=self.state)
         elapsed_time = all_end_time - all_start_time
@@ -442,6 +444,7 @@ def perform_eigendecomposition(
             )
         self.state.wait_for_everyone()
 
+        self._reset_memory()
         eigen_factors = None
         if self.state.is_main_process:
             start_time = time.time()
@@ -462,6 +465,7 @@ def perform_eigendecomposition(
                 output_dir=factors_output_dir, factors=eigen_factors, metadata=factor_args.to_str_dict()
             )
             self.logger.info(f"Saved eigendecomposition results at `{factors_output_dir}`.")
+        self._reset_memory()
         self.state.wait_for_everyone()
 
         self._log_profile_summary(name=f"factors_{factors_name}_eigendecomposition")
@@ -645,6 +649,7 @@ def fit_lambda_matrices(
                     total_data_examples=max_partition_examples,
                 )
 
+            self._reset_memory()
             start_time = get_time(state=self.state)
             with self.profiler.profile("Fit Lambda"):
                 loader = self._get_dataloader(
@@ -680,8 +685,9 @@ def fit_lambda_matrices(
                     metadata=factor_args.to_str_dict(),
                 )
             self.state.wait_for_everyone()
-            del lambda_factors, loader
             self.logger.info(f"Saved Lambda matrices at `{factors_output_dir}`.")
+            del num_data_processed, lambda_factors, loader
+            self._reset_memory()
 
         all_end_time = get_time(state=self.state)
         elapsed_time = all_end_time - all_start_time
diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py
index bdefda1..c9ab6be 100644
--- a/kronfluence/factor/covariance.py
+++ b/kronfluence/factor/covariance.py
@@ -192,7 +192,6 @@ def fit_covariance_matrices_with_loader(
         mode=ModuleMode.COVARIANCE,
         release_memory=True,
     )
-    release_memory()
 
     total_steps = 0
     num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
@@ -233,6 +232,7 @@ def fit_covariance_matrices_with_loader(
                 state.wait_for_everyone()
 
         num_data_processed.add_(find_batch_size(data=batch))
+        del batch, attention_mask, loss
         total_steps += 1
         pbar.update(1)
 
@@ -260,7 +260,6 @@ def fit_covariance_matrices_with_loader(
     if enable_amp:
         set_gradient_scale(model=model, gradient_scale=1.0)
     set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
-    release_memory()
     state.wait_for_everyone()
 
     return num_data_processed, saved_factors
diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py
index 23aedab..55abc97 100644
--- a/kronfluence/factor/eigen.py
+++ b/kronfluence/factor/eigen.py
@@ -392,7 +392,6 @@ def fit_lambda_matrices_with_loader(
     if eigen_factors is not None:
         for name in eigen_factors:
             set_factors(model=model, factor_name=name, factors=eigen_factors[name])
-    release_memory()
 
     total_steps = 0
     num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
@@ -420,6 +419,7 @@ def fit_lambda_matrices_with_loader(
                     sample=not factor_args.use_empirical_fisher,
                 )
                 scaler.scale(loss).backward()
+                del loss
 
                 if factor_args.has_shared_parameters:
                     finalize_iteration(model=model, tracked_module_names=tracked_module_names)
@@ -459,6 +459,5 @@ def fit_lambda_matrices_with_loader(
         set_gradient_scale(model=model, gradient_scale=1.0)
     set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
     state.wait_for_everyone()
-    release_memory()
 
     return num_data_processed, saved_factors
diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py
index 6ae5b95..a96f810 100644
--- a/kronfluence/module/tracked_module.py
+++ b/kronfluence/module/tracked_module.py
@@ -27,6 +27,7 @@
     PRECONDITIONED_GRADIENT_TYPE,
     SELF_SCORE_VECTOR_NAME,
 )
+from kronfluence.utils.state import State
 
 
 class ModuleMode(str, BaseEnum):
@@ -88,13 +89,14 @@ def __init__(
         self._constant: torch.Tensor = nn.Parameter(
             torch.zeros(
                 1,
+                dtype=self.original_module.weight.dtype,
                 requires_grad=True,
-                dtype=torch.float16,
             )
         )
 
         self.current_mode = ModuleMode.DEFAULT
         self.factor_args = FactorArguments() if factor_args is None else factor_args
         self.score_args = ScoreArguments() if score_args is None else score_args
+        self.state = State()
         self.per_sample_gradient_process_fnc = per_sample_gradient_process_fnc
         self.einsum_expression = None
@@ -134,7 +136,8 @@ def __init__(
 
     def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
         """A forward pass of the tracked module. This should have identical behavior to that of the original module."""
-        return self.original_module(inputs + self._constant, *args, **kwargs)
+        # return self.original_module(inputs + self._constant, *args, **kwargs)
+        return self.original_module(inputs, *args, **kwargs) + self._constant
 
     def prepare_storage(self, device: torch.device) -> None:
         """Performs any necessary operations on storage before computing any metrics."""
diff --git a/kronfluence/module/tracker/base.py b/kronfluence/module/tracker/base.py
index 24181be..b6027a3 100644
--- a/kronfluence/module/tracker/base.py
+++ b/kronfluence/module/tracker/base.py
@@ -17,6 +17,7 @@ def __init__(self, module: nn.Module) -> None:
         """
         self.module = module
         self.registered_hooks: List[RemovableHandle] = []
+        self.cached_hooks: List[RemovableHandle] = []
         self.cached_activations: Optional[Union[List[torch.Tensor]], torch.Tensor] = None
         self.cached_per_sample_gradient: Optional[torch.Tensor] = None
 
@@ -32,6 +33,11 @@ def clear_all_cache(self) -> None:
         del self.cached_activations, self.cached_per_sample_gradient
         self.cached_activations, self.cached_per_sample_gradient = None, None
 
+        while self.cached_hooks:
+            handle = self.cached_hooks.pop()
+            handle.remove()
+        self.cached_hooks = []
+
     def _scale_output_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
         """Scales the output gradient and convert to the target dtype.
 
diff --git a/kronfluence/module/tracker/factor.py b/kronfluence/module/tracker/factor.py
index 6208278..917c226 100644
--- a/kronfluence/module/tracker/factor.py
+++ b/kronfluence/module/tracker/factor.py
@@ -82,16 +82,20 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N
     def register_hooks(self) -> None:
         """Sets up hooks to compute activation and gradient covariance matrices."""
 
-        @torch.no_grad()
         def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
             del module
-            # Computes and updates activation covariance during forward pass.
-            input_activation = inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype)
-            self._update_activation_covariance_matrix(input_activation=input_activation)
-            outputs.register_hook(backward_hook)
+            with torch.no_grad():
+                # Computes and updates activation covariance during forward pass.
+                input_activation = (
+                    inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype, copy=True)
+                )
+                self._update_activation_covariance_matrix(input_activation=input_activation)
+            self.cached_hooks.append(outputs.register_hook(backward_hook))
 
         @torch.no_grad()
         def backward_hook(output_gradient: torch.Tensor) -> None:
+            handle = self.cached_hooks.pop()
+            handle.remove()
             # Computes and updates pseudo-gradient covariance during backward pass.
             original_dtype = output_gradient.dtype
             target_dtype = self.module.factor_args.gradient_covariance_dtype
@@ -103,7 +107,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
                 output_gradient = output_gradient * self.module.gradient_scale
             self._update_gradient_covariance_matrix(output_gradient=output_gradient)
 
-        self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
+        # self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
+        self.registered_hooks.append(self.module.register_forward_hook(forward_hook))
 
     def exist(self) -> bool:
         """Checks if both activation and gradient covariance matrices are available."""
@@ -127,7 +132,6 @@ def synchronize(self, num_processes: int) -> None:
     def release_memory(self) -> None:
         """Clears all covariance matrices from memory."""
         for covariance_factor_name in COVARIANCE_FACTOR_NAMES:
-            del self.module.storage[covariance_factor_name]
             self.module.storage[covariance_factor_name] = None
 
 
@@ -214,26 +218,27 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
     def register_hooks(self) -> None:
         """Sets up hooks to compute lambda matrices."""
 
-        @torch.no_grad()
        def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
             del module
-            cached_activation = inputs[0].detach()
-            device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device
-            cached_activation = cached_activation.to(
-                device=device,
-                dtype=self.module.factor_args.per_sample_gradient_dtype,
-                copy=True,
-            )
-
-            if self.module.factor_args.has_shared_parameters:
-                if self.cached_activations is None:
-                    self.cached_activations = []
-                self.cached_activations.append(cached_activation)
-            else:
-                self.cached_activations = cached_activation
+            with torch.no_grad():
+                cached_activation = inputs[0].detach()
+                device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device
+                cached_activation = cached_activation.to(
+                    device=device,
+                    dtype=self.module.factor_args.per_sample_gradient_dtype,
+                    copy=True,
+                )
+                if self.module.factor_args.has_shared_parameters:
+                    if self.cached_activations is None:
+                        self.cached_activations = []
+                    self.cached_activations.append(cached_activation)
+                else:
+                    self.cached_activations = cached_activation
 
-            outputs.register_hook(
-                shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook
+            self.cached_hooks.append(
+                outputs.register_hook(
+                    shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook
+                )
             )
 
@@ -241,6 +246,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
             if self.cached_activations is None:
                 self._raise_cache_not_found_exception()
 
+            handle = self.cached_hooks.pop()
+            handle.remove()
             original_dtype = output_gradient.dtype
             target_dtype = self.module.factor_args.per_sample_gradient_dtype
             output_gradient = output_gradient.detach().to(dtype=target_dtype)
@@ -258,6 +265,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         @torch.no_grad()
         def shared_backward_hook(output_gradient: torch.Tensor) -> None:
+            handle = self.cached_hooks.pop()
+            handle.remove()
             original_dtype = output_gradient.dtype
             target_dtype = self.module.factor_args.per_sample_gradient_dtype
             output_gradient = output_gradient.detach().to(dtype=target_dtype)
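
The standalone sketch below (not part of the patch) illustrates the hook-caching pattern the diff introduces in the trackers, using only standard PyTorch. `ToyTracker` and its methods are hypothetical stand-ins for kronfluence's `BaseTracker`/`CovarianceTracker`: a forward hook registers a per-tensor backward hook on the module output, stores the returned `RemovableHandle` in `cached_hooks`, the backward hook removes itself once it fires, and `clear_all_cache` drops any handles left behind when no backward pass runs.

# Illustrative sketch of the cached-hook lifecycle added by this patch.
# Assumed/hypothetical names: ToyTracker (stand-in for kronfluence trackers).
from typing import List

import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle


class ToyTracker:
    def __init__(self, module: nn.Module) -> None:
        self.module = module
        self.registered_hooks: List[RemovableHandle] = []
        self.cached_hooks: List[RemovableHandle] = []

    def register_hooks(self) -> None:
        def forward_hook(module: nn.Module, inputs, outputs: torch.Tensor) -> None:
            del module, inputs
            # Cache the backward-hook handle so it can be removed later.
            self.cached_hooks.append(outputs.register_hook(backward_hook))

        @torch.no_grad()
        def backward_hook(output_gradient: torch.Tensor) -> None:
            # Remove the tensor hook as soon as it fires, mirroring the patch.
            handle = self.cached_hooks.pop()
            handle.remove()
            print("received gradient with shape", tuple(output_gradient.shape))

        # The patch registers the forward hook on the wrapper (TrackedModule)
        # itself; here we register it on the plain module we were given.
        self.registered_hooks.append(self.module.register_forward_hook(forward_hook))

    def clear_all_cache(self) -> None:
        # Matches the new clear_all_cache logic: drop hooks that never fired.
        while self.cached_hooks:
            self.cached_hooks.pop().remove()


if __name__ == "__main__":
    layer = nn.Linear(4, 2)
    tracker = ToyTracker(layer)
    tracker.register_hooks()
    loss = layer(torch.randn(3, 4)).sum()
    loss.backward()  # triggers backward_hook, which removes its own handle
    tracker.clear_all_cache()

Registering the forward hook on the wrapper rather than on `original_module`, as the patch does, likely ensures the backward hook is attached to the tensor the wrapper actually returns (which now includes the `+ self._constant` term); that reading is an inference from the diff, not stated by the author.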