diff --git a/kronfluence/analyzer.py b/kronfluence/analyzer.py index 7f79c8e..615c1c0 100644 --- a/kronfluence/analyzer.py +++ b/kronfluence/analyzer.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Dict, Optional, Union @@ -45,8 +46,7 @@ def prepare_model( class Analyzer(FactorComputer, ScoreComputer): - """Handles the computation of factors (e.g., covariance and Lambda matrices for EKFAC) and - influence scores for a given PyTorch model.""" + """Handles the computation of factors (e.g., covariance matrices) and scores for a given PyTorch model.""" def __init__( self, @@ -83,7 +83,7 @@ def __init__( output_dir (str): Directory path for storing analysis results. Defaults to './influence_results'. disable_model_save (bool, optional): - If `True`, prevents saving the model's state_dict. Defaults to `True`. + If `True`, prevents saving the model's `state_dict`. Defaults to `True`. Raises: ValueError: @@ -100,8 +100,8 @@ def __init__( disable_tqdm=disable_tqdm, output_dir=output_dir, ) - self.logger.info(f"Initializing Computer with parameters: {locals()}") - self.logger.debug(f"Process state configuration:\n{repr(self.state)}") + self.logger.info(f"Initializing `Analyzer` with parameters: {locals()}") + self.logger.info(f"Process state configuration:\n{repr(self.state)}") # Save model parameters if necessary. if self.state.is_main_process and not disable_model_save: @@ -113,7 +113,7 @@ def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None: Args: dataloader_kwargs (DataLoaderKwargs): - The object containing arguments for DataLoader. + The object containing arguments for PyTorch DataLoader. """ self._dataloader_params = dataloader_kwargs @@ -121,7 +121,7 @@ def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None: def _save_model(self) -> None: """Saves the model to the output directory.""" model_save_path = self.output_dir / "model.safetensors" - extracted_model = extract_model_from_parallel(model=self.model, keep_fp32_wrapper=True) + extracted_model = extract_model_from_parallel(model=copy.deepcopy(self.model), keep_fp32_wrapper=True) if model_save_path.exists(): self.logger.info(f"Found existing saved model at `{model_save_path}`.") @@ -151,13 +151,13 @@ def fit_all_factors( factor_args: Optional[FactorArguments] = None, overwrite_output_dir: bool = False, ) -> None: - """Computes all necessary factors for the given factor strategy. + """Computes all necessary factors for the given strategy. Args: factors_name (str): Unique identifier for the factor, used for organizing results. dataset (data.Dataset): - Dataset used to fit all the factors. + Dataset used to fit all influence factors. per_device_batch_size (int, optional): Per-device batch size for factor fitting. If not specified, executable per-device batch size is automatically determined. @@ -168,7 +168,7 @@ def fit_all_factors( factor_args (FactorArguments, optional): Arguments for factor computation. Defaults to `FactorArguments` default values. overwrite_output_dir (bool, optional): - If `True`, overwrites existing factors with the same name. Defaults to `False`. + If `True`, overwrites existing factors with the same `factors_name`. Defaults to `False`. """ self.fit_covariance_matrices( factors_name=factors_name, @@ -211,7 +211,7 @@ def load_file(path: Union[str, Path]) -> Dict[str, torch.Tensor]: If the specified file does not exist. Note: - For more information on safetensors, see https://github.com/huggingface/safetensors. 
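# A minimal usage sketch for the `load_file` helper documented here, assuming factors were
# already computed and saved under the analyzer's output directory; the path below is
# hypothetical and only illustrates the call.
from kronfluence.analyzer import Analyzer

factors = Analyzer.load_file("influence_results/factors_example/activation_covariance.safetensors")
for module_name, tensor in factors.items():
    print(module_name, tuple(tensor.shape))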
+ For more information on `safetensors`, see https://github.com/huggingface/safetensors. """ if isinstance(path, str): path = Path(path).resolve() diff --git a/kronfluence/arguments.py b/kronfluence/arguments.py index b3af8ba..242688d 100644 --- a/kronfluence/arguments.py +++ b/kronfluence/arguments.py @@ -168,13 +168,6 @@ class ScoreArguments(Arguments): default=False, metadata={"help": "If `True`, offloads cached activations to CPU memory when computing per-sample gradients."}, ) - einsum_minimize_size: bool = field( - default=False, - metadata={ - "help": "If `True`, einsum operations find the contraction that minimizes the size of the " - "largest intermediate tensor." - }, - ) # Partition configuration # data_partitions: int = field( @@ -209,7 +202,7 @@ class ScoreArguments(Arguments): query_gradient_low_rank: Optional[int] = field( default=None, metadata={ - "help": "Rank for the low-rank approximation of the query gradient. " + "help": "Rank for the low-rank approximation of the query gradient (query batching). " "If `None`, no low-rank approximation is applied." }, ) @@ -248,7 +241,7 @@ class ScoreArguments(Arguments): ) per_sample_gradient_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Data type for per-sample gradient computation."}, + metadata={"help": "Data type for query per-sample gradient computation."}, ) precondition_dtype: torch.dtype = field( default=torch.float32, @@ -260,8 +253,8 @@ class ScoreArguments(Arguments): ) def __post_init__(self) -> None: - if self.damping_factor is not None and self.damping_factor <= 0: - raise ValueError("`damping_factor` must be None or positive.") + if self.damping_factor is not None and self.damping_factor < 0: + raise ValueError("`damping_factor` must be `None` or positive.") if any(partition <= 0 for partition in [self.data_partitions, self.module_partitions]): raise ValueError("Both data and module partitions must be positive.") @@ -270,4 +263,4 @@ def __post_init__(self) -> None: raise ValueError("`query_gradient_accumulation_steps` must be positive.") if self.query_gradient_low_rank is not None and self.query_gradient_low_rank <= 0: - raise ValueError("`query_gradient_low_rank` must be None or positive.") + raise ValueError("`query_gradient_low_rank` must be `None` or positive.") diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index c9ab6be..35a7456 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -29,7 +29,7 @@ PARTITION_TYPE, ) from kronfluence.utils.logger import TQDM_BAR_FORMAT -from kronfluence.utils.state import State, no_sync, release_memory +from kronfluence.utils.state import State, no_sync def covariance_matrices_save_path( diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index 47feb23..90dbdd2 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -221,8 +221,6 @@ def perform_eigendecomposition( pbar.update(1) - release_memory() - return eigen_factors @@ -391,7 +389,7 @@ def fit_lambda_matrices_with_loader( ) if eigen_factors is not None: for name in eigen_factors: - set_factors(model=model, factor_name=name, factors=eigen_factors[name]) + set_factors(model=model, factor_name=name, factors=eigen_factors[name], clone=True) total_steps = 0 num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False) diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index 062a561..1d30cec 100644 --- a/kronfluence/module/conv2d.py +++ 
b/kronfluence/module/conv2d.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from einconv.utils import get_conv_paddings from einops import rearrange, reduce -from opt_einsum import DynamicProgramming, contract_expression, contract +from opt_einsum import DynamicProgramming, contract, contract_expression from torch import nn from torch.nn.modules.utils import _pair @@ -116,7 +116,6 @@ def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torc tensor=input_activation, pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", ) - if self.original_module.bias is not None: input_activation = torch.cat( [ @@ -145,7 +144,6 @@ def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Ten tensor=input_activation, pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", ) - if self.original_module.bias is not None: input_activation = torch.cat( [ @@ -160,7 +158,7 @@ def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradien input_activation = self._flatten_input_activation(input_activation=input_activation) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - summed_gradient = contract("bci,bco->io", output_gradient, input_activation) + summed_gradient = contract("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0) return summed_gradient.view((1, *summed_gradient.size())) def compute_per_sample_gradient( @@ -195,9 +193,7 @@ def compute_pairwise_score( right_mat.shape, output_gradient.shape, input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), + optimize=DynamicProgramming(search_outer=True, minimize="size"), ) return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation) @@ -207,9 +203,7 @@ def compute_pairwise_score( preconditioned_gradient.shape, output_gradient.shape, input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), + optimize=DynamicProgramming(search_outer=True, minimize="flops"), ) return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) @@ -225,8 +219,6 @@ def compute_self_measurement_score( preconditioned_gradient.shape, output_gradient.shape, input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), + optimize=DynamicProgramming(search_outer=True, minimize="flops"), ) return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index bb2b739..d8d4a51 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -62,7 +62,7 @@ def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Ten def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: input_activation = self._flatten_input_activation(input_activation=input_activation) - summed_gradient = contract("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(0) + summed_gradient = contract("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(dim=0) return summed_gradient def compute_per_sample_gradient( @@ -93,25 +93,23 @@ def compute_pairwise_score( right_mat.shape, 
output_gradient.shape, input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), + optimize=DynamicProgramming(search_outer=True, minimize="size"), ) return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation) if self.einsum_expression is None: if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: expr = "qio,bti,bto->qbt" + minimize = "size" else: expr = "qio,b...i,b...o->qb" + minimize = "flops" self.einsum_expression = contract_expression( expr, preconditioned_gradient.shape, output_gradient.shape, input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), + optimize=DynamicProgramming(search_outer=True, minimize=minimize), ) return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) @@ -125,8 +123,6 @@ def compute_self_measurement_score( preconditioned_gradient.shape, output_gradient.shape, input_activation.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), + optimize=DynamicProgramming(search_outer=True, minimize="flops"), ) return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 047d28d..b05e023 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -54,7 +54,14 @@ class TrackedModule(nn.Module): SUPPORTED_MODULES: Dict[Type[nn.Module], Any] = {} def __init_subclass__(cls, module_type: Type[nn.Module] = None, **kwargs: Any) -> None: - """Automatically registers subclasses as supported modules.""" + """Automatically registers subclasses as supported modules. + + Args: + module_type (Type[nn.Module], optional): + The type of module this subclass supports. + **kwargs: + Additional keyword arguments. + """ super().__init_subclass__(**kwargs) if module_type is not None: cls.SUPPORTED_MODULES[module_type] = cls @@ -75,17 +82,16 @@ def __init__( original_module (nn.Module): The original module to be wrapped. factor_args (FactorArguments, optional): - Arguments for computing influence factors. + Arguments for computing factors. score_args (ScoreArguments, optional): Arguments for computing influence scores. per_sample_gradient_process_fnc (Callable, optional): - Function to post-process per-sample gradients. + Optional function to post-process per-sample gradients. """ super().__init__() self.name = name self.original_module = original_module - # A way to avoid Autograd computing the gradient with respect to the model parameters. 
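# The zero-valued `_constant` parameter assigned below keeps the wrapped module's output in
# the autograd graph, so tensor hooks registered on the output still fire during backward
# even when the original parameters need no gradients. A standalone, simplified sketch of
# the same idea (`Wrapped` is a hypothetical stand-in for `TrackedModule`):
import torch
from torch import nn

class Wrapped(nn.Module):
    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner
        # Adding a zero-valued parameter never changes the output's value.
        self._constant = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The addition gives the output a grad_fn, so hooks registered on it run during backward.
        return self.inner(x) + self._constant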
self._constant: torch.Tensor = nn.Parameter( torch.zeros( 1, @@ -96,9 +102,7 @@ def __init__( self.current_mode = ModuleMode.DEFAULT self.factor_args = FactorArguments() if factor_args is None else factor_args self.score_args = ScoreArguments() if score_args is None else score_args - self.state = State() self.per_sample_gradient_process_fnc = per_sample_gradient_process_fnc - self.einsum_expression = None self._trackers = { ModuleMode.DEFAULT: BaseTracker(self), @@ -114,6 +118,13 @@ def __init__( self.attention_mask: Optional[torch.Tensor] = None self.gradient_scale: float = 1.0 self.storage: Dict[str, Optional[Union[torch.Tensor, PRECONDITIONED_GRADIENT_TYPE]]] = {} + self.state: State = State() + self.einsum_expression: Optional[Callable] = None + + self._initialize_storage() + + def _initialize_storage(self) -> None: + """Initializes trackers for different module modes.""" # Storage for activation and pseudo-gradient covariance matrices # for covariance_factor_name in COVARIANCE_FACTOR_NAMES: @@ -142,7 +153,7 @@ def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any: return outputs + self._constant def prepare_storage(self, device: torch.device) -> None: - """Performs any necessary operations on storage before computing any metrics.""" + """Performs any necessary operations on storage before computing influence scores.""" FactorConfig.CONFIGS[self.factor_args.strategy].prepare( storage=self.storage, score_args=self.score_args, diff --git a/kronfluence/module/tracker/base.py b/kronfluence/module/tracker/base.py index b6027a3..1e0fc3f 100644 --- a/kronfluence/module/tracker/base.py +++ b/kronfluence/module/tracker/base.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union import torch -import torch.nn as nn +from torch import nn from torch.utils.hooks import RemovableHandle @@ -23,36 +23,45 @@ def __init__(self, module: nn.Module) -> None: def release_hooks(self) -> None: """Removes all registered hooks.""" + self.clear_all_cache() while self.registered_hooks: handle = self.registered_hooks.pop() handle.remove() self.registered_hooks = [] def clear_all_cache(self) -> None: - """Clears all cached data from memory.""" + """Clears all cached data and removes cached hooks.""" del self.cached_activations, self.cached_per_sample_gradient self.cached_activations, self.cached_per_sample_gradient = None, None - while self.cached_hooks: handle = self.cached_hooks.pop() handle.remove() self.cached_hooks = [] - def _scale_output_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: - """Scales the output gradient and convert to the target dtype. + def _raise_cache_not_found_exception(self) -> None: + """Raises an exception when cached activations are not found.""" + raise RuntimeError( + f"Module '{self.module.name}' has no cached activations. This can occur if:\n" + f"1. The module was not used during the forward pass, or\n" + f"2. The module was encountered multiple times in the forward pass.\n" + f"For case 2, set 'has_shared_parameters=True' to enable parameter sharing." + ) + + def _preprocess_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + """Preprocesses the output gradient. Args: output_gradient (torch.Tensor): - The output gradient to scale. + The original output gradient. target_dtype (torch.dtype): - The desired dtype for the output. + The desired data type for the gradient tensor. Returns: torch.Tensor: - The scaled gradient in the target dtype. + The preprocessed gradient. 
""" original_dtype = output_gradient.dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) + output_gradient = output_gradient.to(dtype=target_dtype) if self.module.gradient_scale != 1.0: if original_dtype != target_dtype: output_gradient.mul_(self.module.gradient_scale) @@ -60,15 +69,6 @@ def _scale_output_gradient(self, output_gradient: torch.Tensor, target_dtype: to output_gradient = output_gradient * self.module.gradient_scale return output_gradient - def _raise_cache_not_found_exception(self) -> None: - """Raises an exception when cached activations are not found.""" - raise RuntimeError( - f"Module '{self.module.name}' has no cached activations. This can occur if:\n" - f"1. The module was not used during the forward pass, or\n" - f"2. The module was encountered multiple times in the forward pass.\n" - f"For case 2, set 'has_shared_parameters=True' to enable parameter sharing." - ) - def register_hooks(self) -> None: """Registers hooks for the module.""" @@ -92,7 +92,7 @@ def truncate(self, keep_size: int) -> None: Args: keep_size (int): - The number of dimension to keep. + The number of dimensions to keep. """ def accumulate_iterations(self) -> None: diff --git a/kronfluence/module/tracker/factor.py b/kronfluence/module/tracker/factor.py index 6d0f969..395b3ac 100644 --- a/kronfluence/module/tracker/factor.py +++ b/kronfluence/module/tracker/factor.py @@ -82,33 +82,28 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N def register_hooks(self) -> None: """Sets up hooks to compute activation and gradient covariance matrices.""" + @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - with torch.no_grad(): - input_activation = ( - inputs[0] - .detach() - .to( - dtype=self.module.factor_args.activation_covariance_dtype, - copy=self.module.attention_mask is not None, - ) + input_activation = ( + inputs[0] + .detach() + .to( + dtype=self.module.factor_args.activation_covariance_dtype, + copy=self.module.attention_mask is not None, ) - # Computes and updates activation covariance during forward pass. - self._update_activation_covariance_matrix(input_activation=input_activation) + ) + # Computes and updates activation covariance during forward pass. + self._update_activation_covariance_matrix(input_activation=input_activation) self.cached_hooks.append(outputs.register_hook(backward_hook)) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.factor_args.gradient_covariance_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale + output_gradient = self._preprocess_gradient( + output_gradient, target_dtype=self.module.factor_args.gradient_covariance_dtype + ) # Computes and updates pseudo-gradient covariance during backward pass. 
self._update_gradient_covariance_matrix(output_gradient=output_gradient) @@ -222,23 +217,22 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: def register_hooks(self) -> None: """Sets up hooks to compute lambda matrices.""" + @torch.no_grad() def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: del module - with torch.no_grad(): - cached_activation = inputs[0].detach() - device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device - cached_activation = cached_activation.to( - device=device, - dtype=self.module.factor_args.per_sample_gradient_dtype, - copy=True, - ) - if self.module.factor_args.has_shared_parameters: - if self.cached_activations is None: - self.cached_activations = [] - self.cached_activations.append(cached_activation) - else: - self.cached_activations = cached_activation - + cached_activation = inputs[0].detach() + device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.factor_args.per_sample_gradient_dtype, + copy=True, + ) + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation self.cached_hooks.append( outputs.register_hook( shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook @@ -251,14 +245,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self._raise_cache_not_found_exception() handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.factor_args.per_sample_gradient_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient, target_dtype=self.module.factor_args.per_sample_gradient_dtype + ) per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=self.cached_activations.to(device=output_gradient.device), output_gradient=output_gradient, @@ -271,14 +260,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None: def shared_backward_hook(output_gradient: torch.Tensor) -> None: handle = self.cached_hooks.pop() handle.remove() - original_dtype = output_gradient.dtype - target_dtype = self.module.factor_args.per_sample_gradient_dtype - output_gradient = output_gradient.detach().to(dtype=target_dtype) - if self.module.gradient_scale != 1.0: - if original_dtype != target_dtype: - output_gradient.mul_(self.module.gradient_scale) - else: - output_gradient = output_gradient * self.module.gradient_scale + output_gradient = self._preprocess_gradient( + output_gradient=output_gradient, target_dtype=self.module.factor_args.per_sample_gradient_dtype + ) cached_activation = self.cached_activations.pop() per_sample_gradient = self.module.compute_per_sample_gradient( input_activation=cached_activation.to(device=output_gradient.device), diff --git a/kronfluence/module/tracker/pairwise_score.py b/kronfluence/module/tracker/pairwise_score.py index a8c400c..94696ee 100644 --- a/kronfluence/module/tracker/pairwise_score.py +++ 
b/kronfluence/module/tracker/pairwise_score.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from opt_einsum import DynamicProgramming, contract_expression +from opt_einsum import DynamicProgramming, contract, contract_expression from kronfluence.module.tracker.base import BaseTracker from kronfluence.utils.constants import ( @@ -32,13 +32,11 @@ def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tenso right_mat.shape, per_sample_gradient.shape, left_mat.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.module.score_args.einsum_minimize_size else "flops" - ), + optimize=DynamicProgramming(search_outer=True, minimize="size"), ) scores = self.module.einsum_expression(right_mat, per_sample_gradient, left_mat) else: - scores = torch.einsum( + scores = contract( "qio,tio->qt", self.module.storage[precondition_name], per_sample_gradient, diff --git a/kronfluence/module/tracker/self_score.py b/kronfluence/module/tracker/self_score.py index a08abd6..248c0e4 100644 --- a/kronfluence/module/tracker/self_score.py +++ b/kronfluence/module/tracker/self_score.py @@ -1,15 +1,11 @@ -from typing import List, Tuple +from typing import Tuple import torch import torch.nn as nn -from opt_einsum import DynamicProgramming, contract_expression from kronfluence.factor.config import FactorConfig from kronfluence.module.tracker.base import BaseTracker from kronfluence.utils.constants import ( - ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, - AGGREGATED_GRADIENT_NAME, - PAIRWISE_SCORE_MATRIX_NAME, PRECONDITIONED_GRADIENT_NAME, SELF_SCORE_VECTOR_NAME, ) @@ -77,9 +73,11 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch. else: self.cached_activations = cached_activation - self.cached_hooks.append(outputs.register_hook( - shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook - )) + self.cached_hooks.append( + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + ) @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: @@ -163,9 +161,7 @@ def _compute_self_measurement_score_with_gradient(self, per_sample_gradient: tor per_sample_gradient (torch.Tensor): The per-sample-gradient tensor for the given batch. 
""" - scores = per_sample_gradient.mul_( - self.module.storage[PRECONDITIONED_GRADIENT_NAME] - ).sum(dim=(1, 2)) + scores = per_sample_gradient.mul_(self.module.storage[PRECONDITIONED_GRADIENT_NAME]).sum(dim=(1, 2)) self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None if self.module.storage[SELF_SCORE_VECTOR_NAME] is None: self.module.storage[SELF_SCORE_VECTOR_NAME] = scores @@ -240,20 +236,6 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self.clear_all_cache() self._compute_self_measurement_score_with_gradient(per_sample_gradient=per_sample_gradient) - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient( - output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype - ) - cached_activation = self.cached_activations.pop() - per_sample_gradient = self.module.compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient, - ) - if self.cached_per_sample_gradient is None: - self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) - self.cached_per_sample_gradient.add_(per_sample_gradient) - self.registered_hooks.append(self.module.register_forward_hook(forward_hook)) @torch.no_grad() diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py index e401a12..d8101a5 100644 --- a/kronfluence/score/dot_product.py +++ b/kronfluence/score/dot_product.py @@ -57,20 +57,19 @@ def compute_dot_products_with_loader( ) release_memory() + cached_module_lst = [] + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + cached_module_lst.append(module) + dataset_size = len(train_loader.dataset) score_chunks: Dict[str, List[torch.Tensor]] = {} if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name] = [] + for module in cached_module_lst: + score_chunks[module.name] = [] else: score_chunks[ALL_MODULE_NAME] = [] - cached_module_lst = [] - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - cached_module_lst.append(module) - total_steps = 0 enable_amp = score_args.amp_dtype is not None diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py index 65b79b1..2a993cb 100644 --- a/kronfluence/score/pairwise.py +++ b/kronfluence/score/pairwise.py @@ -192,7 +192,6 @@ def compute_pairwise_scores_with_loaders( factors=loaded_factors[name], ) prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) - release_memory() total_scores_chunks: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {} total_query_batch_size = per_device_query_batch_size * state.num_processes diff --git a/kronfluence/task.py b/kronfluence/task.py index 09a1736..33472a3 100644 --- a/kronfluence/task.py +++ b/kronfluence/task.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union, Dict import torch from torch import nn @@ -34,14 +34,14 @@ def compute_train_loss( The PyTorch model used for loss computation. sample (bool): Indicates whether to sample from the model's outputs or to use the actual targets from the - batch. Defaults to `False`. The case where `sample` is set to `True` must be implemented to + batch. 
Defaults to `False`. The case where `sample=True` must be implemented to approximate the true Fisher. Returns: torch.Tensor: The computed loss as a scalar tensor. """ - raise NotImplementedError("Subclasses must implement the `compute_train_loss` method.") + raise NotImplementedError(f"{self.__class__.__name__} must implement the `compute_train_loss` method.") @abstractmethod def compute_measurement( @@ -64,22 +64,22 @@ def compute_measurement( torch.Tensor: The computed measurable quantity as a tensor. """ - raise NotImplementedError("Subclasses must implement the `compute_measurement` method.") + raise NotImplementedError(f"{self.__class__.__name__} must implement the `compute_measurement` method.") def get_influence_tracked_modules(self) -> Optional[List[str]]: - """Specifies which modules should be tracked for influence score computations. + """Specifies which modules should be tracked for influence factor and score computations. Override this method in subclasses to return a list of specific module names if influence functions should only be computed for a subset of the model. Returns: Optional[List[str]]: - A list of module names to compute influence functions for, or None to compute for - all applicable modules (e.g., nn.Linear, nn.Conv2d). + A list of module names to compute influence functions for, or `None` to compute for + all applicable modules (e.g., `nn.Linear` and `nn.Conv2d`). """ def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]: - """Returns attention masks for padded sequences in a batch. + """Gets attention masks for padded sequences in a batch. This method is typically used for models or datasets that require masking, such as transformer-based architectures. For more information, see: https://huggingface.co/docs/transformers/en/glossary#attention-mask. diff --git a/kronfluence/utils/common/score_arguments.py b/kronfluence/utils/common/score_arguments.py index 66b7aa9..40675f6 100644 --- a/kronfluence/utils/common/score_arguments.py +++ b/kronfluence/utils/common/score_arguments.py @@ -11,7 +11,6 @@ def default_score_arguments( """Creates default score arguments""" score_args = ScoreArguments(damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank) if score_args.query_gradient_low_rank is not None: - score_args.einsum_minimize_size = True score_args.query_gradient_accumulation_steps = 10 return score_args @@ -67,7 +66,6 @@ def reduce_memory_score_arguments( score_args = all_low_precision_score_arguments( damping_factor=damping_factor, query_gradient_low_rank=query_gradient_low_rank, dtype=dtype ) - score_args.einsum_minimize_size = True score_args.offload_activations_to_cpu = True return score_args diff --git a/kronfluence/utils/dataset.py b/kronfluence/utils/dataset.py index e9836ac..9f33b22 100644 --- a/kronfluence/utils/dataset.py +++ b/kronfluence/utils/dataset.py @@ -110,7 +110,7 @@ class DistributedEvalSampler(Sampler[T_co]): def __init__( # pylint: disable=super-init-not-called self, - dataset: torch.utils.data.Dataset, + dataset: data.Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0, diff --git a/kronfluence/utils/logger.py b/kronfluence/utils/logger.py index 5c90d5e..2ca1a0f 100644 --- a/kronfluence/utils/logger.py +++ b/kronfluence/utils/logger.py @@ -278,7 +278,12 @@ def summary(self) -> str: # Timing utilities copied from: # https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/pytorch_utils.py. 
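# CUDA kernels launch asynchronously, so wall-clock measurements must synchronize the device
# before reading the clock, which is what the helpers below do. A minimal sketch of timing a
# callable the same way (the function name is illustrative):
import time
import torch

def timed(fn) -> float:
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        torch.cuda.synchronize()
    start = time.monotonic()
    fn()
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        torch.cuda.synchronize()
    return time.monotonic() - start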
def _get_monotonic_time() -> float: - """Gets the monotonic time after the CUDA synchronization if necessary.""" + """Gets the time after the CUDA synchronization. + + Returns: + float: + The current time. + """ if torch.cuda.is_available() and torch.cuda.is_initialized(): torch.cuda.synchronize() return time.monotonic() @@ -286,7 +291,16 @@ def _get_monotonic_time() -> float: @torch.no_grad() def get_time(state: State) -> float: - """Gets the current time after synchronizing with other devices.""" + """Gets the current time after synchronizing with other devices. + + Args: + state (State): + The current process's information (e.g., device being used). + + Returns: + float: + The current time. + """ if not state.use_distributed: if torch.cuda.is_available() and torch.cuda.is_initialized(): torch.cuda.synchronize() diff --git a/kronfluence/utils/state.py b/kronfluence/utils/state.py index 8f4da96..436dea0 100644 --- a/kronfluence/utils/state.py +++ b/kronfluence/utils/state.py @@ -7,7 +7,6 @@ import torch.distributed as dist from accelerate.state import SharedDict from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel class State: @@ -50,7 +49,12 @@ def __init__(self, cpu: bool = False) -> None: self.device = torch.device("cpu") if self.cpu else self.default_device def __repr__(self) -> str: - """Provides a string representation of the State instance.""" + """Provides a string representation of the `State` instance. + + Returns: + str: + A formatted string containing process and device information. + """ return ( f"Num processes: {self.num_processes}\n" f"Process index: {self.process_index}\n" @@ -65,7 +69,7 @@ def _reset_state() -> None: @property def initialized(self) -> bool: - """Checks if the State has been initialized.""" + """Checks if the `State` has been initialized.""" return self._shared_state != {} @property @@ -99,28 +103,40 @@ def wait_for_everyone(self) -> None: @property def default_device(self) -> torch.device: - """Determines the default device (CUDA if available, otherwise CPU).""" + """Determines the default device (CUDA if available, otherwise CPU). + + Returns: + torch.device: + The default device. + """ if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def release_memory() -> None: - """Releases unused memory. This function calls Python's garbage collector and empties CUDA cache - if CUDA is available.""" + """Releases unused memory. + + This function calls Python's garbage collector and empties CUDA cache if CUDA is available. + """ gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def get_active_tensors() -> List[torch.Tensor]: - # https://discuss.pytorch.org/t/how-to-debug-causes-of-gpu-memory-leaks/6741/3 + """Gets a list of active tensors in memory. + + Returns: + List[torch.Tensor]: + A list of tuples containing tensor type and size. 
+ """ tensor_lst = [] for obj in gc.get_objects(): try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)): tensor_lst.append(type(obj), obj.size()) - except: + except: # pylint: disable=bare-except pass return tensor_lst @@ -144,7 +160,7 @@ def no_sync(model: nn.Module, state: State) -> Callable: """ context = contextlib.nullcontext - if state.use_distributed and not isinstance(model, FullyShardedDataParallel): + if state.use_distributed: context = getattr(model, "no_sync", context) with context(): diff --git a/tests/modules/test_matmul.py b/tests/modules/test_matmul.py index e69de29..42b38a4 100644 --- a/tests/modules/test_matmul.py +++ b/tests/modules/test_matmul.py @@ -0,0 +1,182 @@ +import opt_einsum +import pytest +import torch +from accelerate.utils import set_seed +from opt_einsum import DynamicProgramming + + +def test_query_gradient_svd( + seed: int = 0, +) -> None: + input_dim = 2048 + output_dim = 1024 + batch_dim = 16 + set_seed(seed) + + gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) + + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + assert torch.allclose(gradient, U @ torch.diag_embed(S) @ V, atol=1e-5, rtol=1e-3) + + rank = 32 + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + assert torch.bmm(left, right).shape == gradient.shape + + rank = input_dim + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + assert torch.allclose(torch.bmm(left, right), gradient, atol=1e-5, rtol=1e-3) + + rank = 32 + lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank), dtype=torch.float64) + lr_gradient2 = torch.rand(size=(batch_dim, rank, input_dim), dtype=torch.float64) + gradient = torch.bmm(lr_gradient1, lr_gradient2) + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + assert torch.allclose(torch.bmm(left_mat, right_mat), gradient, atol=1e-5, rtol=1e-3) + + query_batch_dim = 32 + new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) + score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) + + lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + assert torch.allclose(score, lr_score) + + lr_score_reconst_matmul = torch.matmul( + torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() + ) + assert torch.allclose(score, lr_score_reconst_matmul) + + # These should be able to avoid explicit reconstruction. This should be used when input_dim > output_dim. + intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) + final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) + assert torch.allclose(score, final) + print("Option 1") + print(intermediate.numel()) + + # This should be used when output_dim > input_dim. 
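# For the dimensions in this test (batch_dim=16, query_batch_dim=32, rank=32, output_dim=1024,
# input_dim=2048), the largest intermediates compare as follows:
#   option 1, "qtko" intermediate: 16 * 32 * 32 * 1024 = 16,777,216 elements
#   option 2, "qtik" intermediate: 16 * 32 * 2048 * 32 = 33,554,432 elements
#   explicit reconstruction, (t, o*i): 16 * 1024 * 2048 = 33,554,432 elements
# so with input_dim > output_dim, option 1 avoids reconstruction with the smallest intermediate.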
+ intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat) + final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2) + assert torch.allclose(score, final2) + print("Option 2") + print(intermediate2.numel()) + + print("Reconstruction") + print((torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel()) + path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + print(path) + + +@pytest.mark.parametrize("input_dim", [256, 512]) +@pytest.mark.parametrize("output_dim", [512, 1024]) +@pytest.mark.parametrize("batch_dim", [8, 16]) +@pytest.mark.parametrize("qbatch_dim", [8, 16]) +@pytest.mark.parametrize("rank", [32]) +@pytest.mark.parametrize("seed", [0]) +def test_query_gradient_svd_reconst( + input_dim: int, + output_dim: int, + batch_dim: int, + qbatch_dim: int, + rank: int, + seed: int, +) -> None: + set_seed(seed) + + lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank + 50), dtype=torch.float64) + lr_gradient2 = torch.rand(size=(batch_dim, rank + 50, input_dim), dtype=torch.float64) + gradient = torch.bmm(lr_gradient1, lr_gradient2) + U, S, V = torch.linalg.svd( + gradient.contiguous(), + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + V_k = V[:, :rank, :].clone() + left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() + new_gradient = torch.rand(size=(qbatch_dim, output_dim, input_dim), dtype=torch.float64) + + lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + lr_score_reconst_matmul = torch.matmul( + torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() + ) + assert torch.allclose(lr_score, lr_score_reconst_matmul) + + # This should be used when input_dim > output_dim. + intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) + final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) + assert torch.allclose(lr_score, final) + print("Option 1") + print(intermediate.numel()) + + # This should be used when output_dim > input_dim. 
+ intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat) + final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2) + assert torch.allclose(lr_score, final2) + print("Option 2") + print(intermediate2.numel()) + + print("Reconstruction") + reconst_numel = (torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel() + print(reconst_numel) + path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat) + print(path) + + if new_gradient.size(0) * right_mat.size(0) * rank * min((right_mat.size(2), left_mat.size(1))) > right_mat.size( + 0 + ) * right_mat.size(2) * left_mat.size(1): + assert intermediate2.numel() > reconst_numel and intermediate.numel() > reconst_numel + elif output_dim >= input_dim: + assert intermediate2.numel() <= reconst_numel + else: + assert intermediate.numel() <= reconst_numel + + +def test_compute_score_matmul( + seed: int = 0, +) -> None: + input_dim = 4096 + output_dim = 100 + token_dim = 1 + batch_dim = 1024 + query_batch_dim = 2 + set_seed(seed) + + input_activation = torch.rand(size=(batch_dim, token_dim, input_dim), dtype=torch.float64) + output_gradient = torch.rand(size=(batch_dim, token_dim, output_dim), dtype=torch.float64) + gradient = opt_einsum.contract("b...i,b...o->bio", output_gradient, input_activation) + new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) + + score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) + path = opt_einsum.contract_path("toi,qoi->tq", gradient, new_gradient) + print(path) + + unsqueeze_score = opt_einsum.contract("t...,q...->tq", gradient, new_gradient) + assert torch.allclose(score, unsqueeze_score) + + path = opt_einsum.contract_path( + "bti,bto,qio->qb", + output_gradient, + input_activation, + new_gradient, + optimize=DynamicProgramming(search_outer=True, minimize="flops"), + ) + print(path) diff --git a/tests/scores/test_pairwise_scores.py b/tests/scores/test_pairwise_scores.py index a0a85db..85f0709 100644 --- a/tests/scores/test_pairwise_scores.py +++ b/tests/scores/test_pairwise_scores.py @@ -27,12 +27,12 @@ "test_name", [ "mlp", - "repeated_mlp", - "conv", - "bert", - "roberta", - "gpt", - "gpt_checkpoint", + # "repeated_mlp", + # "conv", + # "bert", + # "roberta", + # "gpt", + # "gpt_checkpoint", ], ) @pytest.mark.parametrize("score_dtype", [torch.float32])
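# Several hunks above drop the removed `einsum_minimize_size` flag and instead fix the
# `DynamicProgramming` objective ("size" or "flops") when building cached einsum expressions.
# A self-contained sketch of that opt_einsum pattern; the dimension sizes are illustrative:
import torch
from opt_einsum import DynamicProgramming, contract_expression

num_query, num_train, dim_i, dim_o = 4, 8, 16, 32
precond_grad = torch.randn(num_query, dim_i, dim_o)
output_grad = torch.randn(num_train, dim_i)
activation = torch.randn(num_train, dim_o)

# Build the contraction once with an explicit dynamic-programming optimizer, then reuse it.
expression = contract_expression(
    "qio,b...i,b...o->qb",
    precond_grad.shape,
    output_grad.shape,
    activation.shape,
    optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)
scores = expression(precond_grad, output_grad, activation)  # shape: (num_query, num_train)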