Skip to content

Commit

Permalink
add debug code
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 5, 2024
1 parent 3dde1de commit 27aa722
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
25 changes: 25 additions & 0 deletions kronfluence/computer/factor_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,19 @@ def fit_covariance_matrices(
)

self._reset_memory()


import gc
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (
hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device(
"cuda"):
print(type(obj), obj.size())
except:
pass


start_time = get_time(state=self.state)
with self.profiler.profile("Fit Covariance"):
loader = self._get_dataloader(
Expand All @@ -323,6 +336,18 @@ def fit_covariance_matrices(
f"Fitted covariance matrices with {num_data_processed.item()} data points in "
f"{elapsed_time:.2f} seconds."
)
self._reset_memory()

print("Done")
import gc
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (
hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device(
"cuda"):
print(type(obj), obj.size())
except:
pass

with self.profiler.profile("Save Covariance"):
if self.state.is_main_process:
Expand Down
22 changes: 0 additions & 22 deletions kronfluence/factor/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,6 @@ def fit_covariance_matrices_with_loader(
- Number of data points processed.
- Computed covariance matrices (nested dict: factor_name -> module_name -> tensor).
"""
import gc
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"):
print(type(obj), obj.size())
except:
pass

update_factor_args(model=model, factor_args=factor_args)
if tracked_module_names is None:
tracked_module_names = get_tracked_module_names(model=model)
Expand Down Expand Up @@ -243,13 +235,6 @@ def fit_covariance_matrices_with_loader(
total_steps += 1
pbar.update(1)

for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"):
print(type(obj), obj.size())
except:
pass

if state.use_distributed:
synchronize_modules(model=model, tracked_module_names=tracked_module_names)
num_data_processed = num_data_processed.to(device=state.device)
Expand All @@ -276,11 +261,4 @@ def fit_covariance_matrices_with_loader(
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
state.wait_for_everyone()

for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)) and obj.device == torch.device("cuda"):
print(type(obj), obj.size())
except:
pass

return num_data_processed, saved_factors

0 comments on commit 27aa722

Please sign in to comment.