diff --git a/.gitignore b/.gitignore
index 1e2ced7..debad09 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,6 +165,7 @@ cython_debug/
 # Checkpoints and influence outputs
 checkpoints/
 analyses/
+influence_results/
 data/
 *.pth
 *.pt
\ No newline at end of file
diff --git a/kronfluence/computer/score_computer.py b/kronfluence/computer/score_computer.py
index f87be0b..2b8fc43 100644
--- a/kronfluence/computer/score_computer.py
+++ b/kronfluence/computer/score_computer.py
@@ -516,7 +516,7 @@ def executable_batch_size_func(batch_size: int) -> None:
                 score_args=score_args,
                 factor_args=factor_args,
                 tracked_module_names=tracked_modules_name,
-                disable_tqdm=True
+                disable_tqdm=True,
             )
 
         per_device_batch_size = find_executable_batch_size(
diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py
index eed783f..b29d38e 100644
--- a/kronfluence/factor/covariance.py
+++ b/kronfluence/factor/covariance.py
@@ -13,13 +13,14 @@
 from kronfluence.arguments import FactorArguments
 from kronfluence.module.tracked_module import ModuleMode
 from kronfluence.module.utils import (
+    get_tracked_module_names,
     load_factors,
     remove_gradient_scale,
     set_attention_mask,
     set_gradient_scale,
     set_mode,
     synchronize_covariance_matrices,
-    update_factor_args, get_tracked_module_names,
+    update_factor_args,
 )
 from kronfluence.task import Task
 from kronfluence.utils.constants import (
diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py
index 0bf5b4d..6bedbaa 100644
--- a/kronfluence/factor/eigen.py
+++ b/kronfluence/factor/eigen.py
@@ -21,7 +21,8 @@
     set_gradient_scale,
     set_mode,
     synchronize_lambda_matrices,
-    update_factor_args, update_aggregated_lambda_matrices,
+    update_aggregated_lambda_matrices,
+    update_factor_args,
 )
 from kronfluence.task import Task
 from kronfluence.utils.constants import (
diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py
index f044d7e..a0255d8 100644
--- a/kronfluence/module/tracked_module.py
+++ b/kronfluence/module/tracked_module.py
@@ -877,9 +877,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
             # The preconditioning factors need to be loaded to appropriate device as they will be
             # used at each iteration.
-            if not self._storge_at_current_device:
-                self._move_storage_to_device(target_device=per_sample_gradient.device)
-                self._storge_at_current_device = True
+            self._move_storage_to_device(target_device=per_sample_gradient.device)
 
             if self._cached_per_sample_gradient is None:
                 self._cached_per_sample_gradient = per_sample_gradient
@@ -941,9 +939,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
             # The preconditioning factors need to be loaded to appropriate device as they will be
             # used at each iteration.
-            if not self._storge_at_current_device:
-                self._move_storage_to_device(target_device=per_sample_gradient.device)
-                self._storge_at_current_device = True
+            self._move_storage_to_device(target_device=per_sample_gradient.device)
 
             if self._cached_per_sample_gradient is None:
                 self._cached_per_sample_gradient = per_sample_gradient
@@ -996,4 +992,3 @@ def release_scores(self) -> None:
         self._cached_activations = []
         del self._cached_per_sample_gradient
         self._cached_per_sample_gradient = None
-        self._storge_at_current_device = False
diff --git a/kronfluence/utils/dataset.py b/kronfluence/utils/dataset.py
index 9b82744..920cf55 100644
--- a/kronfluence/utils/dataset.py
+++ b/kronfluence/utils/dataset.py
@@ -17,7 +17,7 @@
 @dataclass
 class DataLoaderKwargs(KwargsHandler):
     """The object used to customize `DataLoader`. Please refer to https://pytorch.org/docs/stable/data.html for
-    detailed information of each argument. The default arguments are copied from PyTorch version 2.2.
+    detailed information of each argument. The default arguments are copied from PyTorch version 2.3.
     """
 
     num_workers: int = 0
@@ -115,7 +115,7 @@ class DistributedSamplerWithStack(Sampler[T_co]):
     """DistributedSampleWithStack is different from `DistributedSampler`. Instead of subsampling,
     it stacks the dataset. For example, when `num_replicas` is 3, and the dataset of [0, ..., 9] is given,
     the first, second, and third rank should have [0, 1, 2], [3, 4, 5], and [6, 7, 8], respectively. However,
-    it still adds extra samples to make the dataset evenly divisible.
+    it still adds extra samples to make the dataset evenly divisible (different from DistributedEvalSampler).
     """
 
     def __init__(  # pylint: disable=super-init-not-called
diff --git a/kronfluence/utils/logger.py b/kronfluence/utils/logger.py
index fba4b88..04af511 100644
--- a/kronfluence/utils/logger.py
+++ b/kronfluence/utils/logger.py
@@ -143,7 +143,7 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str
 
 
 class PassThroughProfiler(Profiler):
-    """A pass through Profiler objective that does not record timing for the profiler."""
+    """A pass through Profiler objective that does not record timing."""
 
     def start(self, action_name: str) -> None:
         """Defines how to start recording an action."""
@@ -161,6 +161,8 @@ def summary(self) -> str:
 class TorchProfiler(Profiler):
     """A PyTorch Profiler objective that provides detailed profiling information:
     https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html.
+
+    This is useful for low-level profiling in PyTorch, and is not used by default.
     """
 
     def __init__(self, state: State) -> None:
@@ -174,7 +176,7 @@ def start(self, action_name: str) -> None:
         """Defines how to start recording an action."""
         if action_name in self.current_actions:
             raise ValueError(f"Attempted to start {action_name} which has already started.")
-        # Set dummy value, since only used to track duplicate actions
+        # Set dummy value, since only used to track duplicate actions.
         self.current_actions[action_name] = 0.0
         self.actions.append(action_name)
         self._torch_prof.start()
diff --git a/kronfluence/utils/model.py b/kronfluence/utils/model.py
index 2731690..1a71a79 100644
--- a/kronfluence/utils/model.py
+++ b/kronfluence/utils/model.py
@@ -15,22 +15,26 @@
 
 
 def apply_ddp(
-    model: torch.nn.Module,
+    model: nn.Module,
     local_rank: int,
     rank: int,
     world_size: int,
 ) -> DistributedDataParallel:
-    """
-    Applies DistributedDataParallel (DDP) to the given model.
+    """Applies DistributedDataParallel (DDP) to the given model.
 
     Args:
-        model (torch.nn.Module): The model to apply DDP to.
-        local_rank (int): The local rank of the current process.
-        rank (int): The rank of the current process.
-        world_size (int): The total number of processes.
+        model (nn.Module):
+            The model for which DDP will be applied.
+        local_rank (int):
+            The local rank of the current process.
+        rank (int):
+            The rank of the current process.
+        world_size (int):
+            The total number of processes.
 
     Returns:
-        DistributedDataParallel: The model wrapped with DDP.
+        DistributedDataParallel:
+            The model wrapped with DDP.
     """
     dist.init_process_group("nccl", rank=rank, world_size=world_size)
     device = torch.device(f"cuda:{local_rank}")
@@ -48,7 +52,7 @@ def apply_ddp(
 
 
 def apply_fsdp(
-    model: torch.nn.Module,
+    model: nn.Module,
     local_rank: int,
     rank: int,
     world_size: int,
@@ -57,27 +61,31 @@
     is_transformer: bool = False,
     layer_to_wrap: Optional[nn.Module] = None,
 ) -> FSDP:
-    """
-    Applies FullyShardedDataParallel (FSDP) to the given model.
+    """Applies FullyShardedDataParallel (FSDP) to the given model.
 
     Args:
-        model (torch.nn.Module): The model to apply FSDP to.
-        local_rank (int): The local rank of the current process.
-        rank (int): The rank of the current process.
-        world_size (int): The total number of processes.
-        sharding_strategy (str): The sharding strategy to use.
-            Defaults to "FULL_SHARD".
-        cpu_offload (bool): Whether to offload parameters to CPU. Check
-            https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.CPUOffload.
-            Defaults to True.
-        is_transformer (bool): Whether the model is a transformer model.
-            Defaults to False.
-        layer_to_wrap (nn.Module, optional): The specific layer to wrap
-            for transformer models. Required if `is_transformer` is True.
+        model (nn.Module):
+            The model for which FSDP will be applied.
+        local_rank (int):
+            The local rank of the current process.
+        rank (int):
+            The rank of the current process.
+        world_size (int):
+            The total number of processes.
+        sharding_strategy (str):
+            The sharding strategy to use. Defaults to "FULL_SHARD".
+        cpu_offload (bool):
+            Whether to offload parameters to CPU. Check
+            https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.CPUOffload. Defaults to True.
+        is_transformer (bool):
+            Whether the model is a transformer model. Defaults to False.
+        layer_to_wrap (nn.Module, optional):
+            The specific layer to wrap for transformer models. Required if `is_transformer` is True.
             Defaults to None.
 
     Returns:
-        FSDP: The model wrapped with FSDP.
+        FullyShardedDataParallel:
+            The model wrapped with FSDP.
     """
     dist.init_process_group("nccl", rank=rank, world_size=world_size)
     device = torch.device(f"cuda:{local_rank}")