From 346d4fd84ca2c57488ca87b15203606d58546226 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 5 Jul 2024 02:03:25 -0400 Subject: [PATCH] Debug code to track memory --- kronfluence/factor/covariance.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index e71b550..5b216ef 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -183,6 +183,14 @@ def fit_covariance_matrices_with_loader( - Number of data points processed. - Computed covariance matrices (nested dict: factor_name -> module_name -> tensor). """ + import gc + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + print(type(obj), obj.size()) + except: + pass + update_factor_args(model=model, factor_args=factor_args) if tracked_module_names is None: tracked_module_names = get_tracked_module_names(model=model) @@ -235,6 +243,13 @@ def fit_covariance_matrices_with_loader( total_steps += 1 pbar.update(1) + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + print(type(obj), obj.size()) + except: + pass + if state.use_distributed: synchronize_modules(model=model, tracked_module_names=tracked_module_names) num_data_processed = num_data_processed.to(device=state.device) @@ -261,4 +276,11 @@ def fit_covariance_matrices_with_loader( set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) state.wait_for_everyone() + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + print(type(obj), obj.size()) + except: + pass + return num_data_processed, saved_factors