Skip to content

Commit

Permalink
Debug code to track memory
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 5, 2024
1 parent 8573276 commit 346d4fd
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions kronfluence/factor/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ def fit_covariance_matrices_with_loader(
- Number of data points processed.
- Computed covariance matrices (nested dict: factor_name -> module_name -> tensor).
"""
import gc
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
print(type(obj), obj.size())
except:
pass

update_factor_args(model=model, factor_args=factor_args)
if tracked_module_names is None:
tracked_module_names = get_tracked_module_names(model=model)
Expand Down Expand Up @@ -235,6 +243,13 @@ def fit_covariance_matrices_with_loader(
total_steps += 1
pbar.update(1)

for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
print(type(obj), obj.size())
except:
pass

if state.use_distributed:
synchronize_modules(model=model, tracked_module_names=tracked_module_names)
num_data_processed = num_data_processed.to(device=state.device)
Expand All @@ -261,4 +276,11 @@ def fit_covariance_matrices_with_loader(
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
state.wait_for_everyone()

for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
print(type(obj), obj.size())
except:
pass

return num_data_processed, saved_factors

0 comments on commit 346d4fd

Please sign in to comment.