diff --git a/kronfluence/computer/score_computer.py b/kronfluence/computer/score_computer.py
index 74699fb..d606124 100644
--- a/kronfluence/computer/score_computer.py
+++ b/kronfluence/computer/score_computer.py
@@ -450,6 +450,7 @@ def compute_pairwise_scores(
             )
             self.state.wait_for_everyone()
             del scores, query_loader, train_loader
+            self._reset_memory()
             self.logger.info(f"Saved pairwise scores at {scores_output_dir}.")
 
             all_end_time = get_time(state=self.state)
@@ -755,6 +756,7 @@ def compute_self_scores(
             )
             self.state.wait_for_everyone()
             del scores, train_loader
+            self._reset_memory()
             self.logger.info(f"Saved self-influence scores at `{scores_output_dir}`.")
 
             all_end_time = get_time(state=self.state)
diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py
index 2a993cb..65b79b1 100644
--- a/kronfluence/score/pairwise.py
+++ b/kronfluence/score/pairwise.py
@@ -192,6 +192,7 @@ def compute_pairwise_scores_with_loaders(
                 factors=loaded_factors[name],
             )
     prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device)
+    release_memory()
 
     total_scores_chunks: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {}
     total_query_batch_size = per_device_query_batch_size * state.num_processes
diff --git a/kronfluence/utils/state.py b/kronfluence/utils/state.py
index c145255..8f4da96 100644
--- a/kronfluence/utils/state.py
+++ b/kronfluence/utils/state.py
@@ -1,7 +1,7 @@
 import contextlib
 import gc
 import os
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 import torch.distributed as dist
@@ -113,6 +113,22 @@ def release_memory() -> None:
         torch.cuda.empty_cache()
 
 
+def get_active_tensors() -> List[Tuple[type, torch.Size]]:
+    """Returns the type and size of each tensor currently tracked by the garbage collector.
+
+    Adapted from https://discuss.pytorch.org/t/how-to-debug-causes-of-gpu-memory-leaks/6741/3.
+    """
+    tensor_lst = []
+    for obj in gc.get_objects():
+        try:
+            # Catch plain tensors as well as objects (e.g., `nn.Parameter`) that wrap a tensor in `.data`.
+            if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
+                tensor_lst.append((type(obj), obj.size()))
+        except Exception:
+            pass
+    return tensor_lst
+
+
 @contextlib.contextmanager
 def no_sync(model: nn.Module, state: State) -> Callable:
     """A context manager to temporarily disable gradient synchronization in distributed setting.