diff --git a/kronfluence/analyzer.py b/kronfluence/analyzer.py index f3582f7..293a7d0 100644 --- a/kronfluence/analyzer.py +++ b/kronfluence/analyzer.py @@ -20,33 +20,33 @@ def prepare_model( model: nn.Module, task: Task, ) -> nn.Module: - """Prepares the model before passing it to `Analyzer`. This function sets `param.requires_grad = False` - for all modules and installs `TrackedModule` to supported modules. This `TrackedModule` keeps track of relevant - statistics needed to compute influence scores. + """Prepares the model for analysis by setting all parameters and buffers to non-trainable + and installing `TrackedModule` wrappers on supported modules. Args: model (nn.Module): - The PyTorch model to be analyzed. + The PyTorch model to be prepared for analysis. task (Task): - The specific task associated with the model. + The specific task associated with the model, used for `TrackedModule` installation. Returns: nn.Module: - The PyTorch model with `param.requires_grad = False` on all modules and with `TrackedModule` installed. + The prepared model with non-trainable parameters and `TrackedModule` wrappers. """ model.eval() for params in model.parameters(): params.requires_grad = False for buffers in model.buffers(): buffers.requires_grad = False - # Install `TrackedModule` to the model. + + # Install `TrackedModule` wrappers on supported modules. model = wrap_tracked_modules(model=model, task=task) return model class Analyzer(FactorComputer, ScoreComputer): - """Handles the computation of all factors (e.g., covariance and Lambda matrices for EKFAC) - and influence scores for a given PyTorch model.""" + """Handles the computation of factors (e.g., covariance and Lambda matrices for EKFAC) and + influence scores for a given PyTorch model.""" def __init__( self, @@ -61,33 +61,33 @@ def __init__( output_dir: str = "./influence_results", disable_model_save: bool = True, ) -> None: - """Initializes an instance of the Analyzer class. + """Initializes an instance of the `Analyzer` class. Args: analysis_name (str): - The unique identifier for the analysis, used to organize and retrieve the results. + Unique identifier for the analysis, used for organizing results. model (nn.Module): The PyTorch model to be analyzed. task (Task): The specific task associated with the model. cpu (bool, optional): - Specifies whether the analysis should be explicitly performed using the CPU. - Defaults to False, utilizing GPU resources if available. + If `True`, forces analysis to be performed on CPU. Defaults to `False`. log_level (int, optional): - The logging level to use (e.g., logging.DEBUG, logging.INFO). Defaults to the root logging level. + Logging level (e.g., logging.DEBUG, logging.INFO). Defaults to root logging level. log_main_process_only (bool, optional): - If True, restricts logging to the main process. Defaults to True. + If `True`, restricts logging to the main process. Defaults to `True`. profile (bool, optional): - Enables the generation of performance profiling logs. This can be useful for - identifying bottlenecks or performance issues. Defaults to False. + If `True`, enables performance profiling logs. Defaults to `False`. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + If `True`, disables TQDM progress bars. Defaults to `False`. output_dir (str): - The file path to the directory, where analysis results will be stored. If the directory - does not exist, it will be created. Defaults to './influence_results'. + Directory path for storing analysis results. 
Defaults to './influence_results'. disable_model_save (bool, optional): - If set to True, prevents the saving of the model's state_dict. When the provided model is different - from the previously saved model, it will raise an Exception. Defaults to True. + If `True`, prevents saving the model's state_dict. Defaults to `True`. + + Raises: + ValueError: + If the provided model differs from a previously saved model when `disable_model_save` is `False`. """ super().__init__( name=analysis_name, @@ -103,17 +103,17 @@ def __init__( self.logger.info(f"Initializing Computer with parameters: {locals()}") self.logger.debug(f"Process state configuration:\n{repr(self.state)}") - # Saves model parameters. + # Save model parameters if necessary. if self.state.is_main_process and not disable_model_save: self._save_model() self.state.wait_for_everyone() def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None: - """Sets the default DataLoader parameters to use for all DataLoaders. + """Sets the default DataLoader arguments. Args: dataloader_kwargs (DataLoaderKwargs): - The object containing parameters for DataLoader. + The object containing arguments for DataLoader. """ self._dataloader_params = dataloader_kwargs @@ -125,7 +125,7 @@ def _save_model(self) -> None: if model_save_path.exists(): self.logger.info(f"Found existing saved model at `{model_save_path}`.") - # Load the existing model's state_dict for comparison. + # Load existing model's `state_dict` for comparison. loaded_state_dict = load_file(model_save_path) if not verify_models_equivalence(loaded_state_dict, extracted_model.state_dict()): error_msg = ( @@ -152,27 +152,24 @@ def fit_all_factors( factor_args: Optional[FactorArguments] = None, overwrite_output_dir: bool = False, ) -> None: - """Computes all necessary factors for the given factor strategy. As an example, EK-FAC - requires (1) computing covariance matrices, (2) performing Eigendecomposition, and - (3) computing Lambda (corrected eigenvalues) matrices. + """Computes all necessary factors for the given factor strategy. Args: factors_name (str): - The unique identifier for the factor, used to organize and retrieve the results. + Unique identifier for the factor, used for organizing results. dataset (data.Dataset): - The dataset that will be used to fit all the factors. + Dataset used to fit all the factors. per_device_batch_size (int, optional): - The per-device batch size used to fit the factors. If not specified, executable - per-device batch size is automatically determined. + Per-device batch size for factor fitting. If not specified, executable per-device batch size + is automatically determined. initial_per_device_batch_size_attempt (int): - The initial attempted per-device batch size when the batch size is not provided. + Initial batch size attempt when `per_device_batch_size` is not explicitly provided. Defaults to `4096`. dataloader_kwargs (DataLoaderKwargs, optional): - Controls additional arguments for PyTorch's DataLoader. + Additional arguments for PyTorch's DataLoader. factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, - the default values of `FactorArguments` will be used. + Arguments for factor computation. Defaults to `FactorArguments` default values. overwrite_output_dir (bool, optional): - If True, the existing factors with the same name will be overwritten. + If `True`, overwrites existing factors with the same name. Defaults to `False`. 
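# --- Illustrative usage sketch (not part of this patch) ---
# A minimal end-to-end call of the API documented above, assuming a user-defined
# `task` (a `Task` subclass), a `model`, and a `train_dataset`; these names are hypothetical.
from kronfluence.analyzer import Analyzer, prepare_model

model = prepare_model(model=model, task=task)  # freeze parameters and install `TrackedModule` wrappers
analyzer = Analyzer(analysis_name="demo", model=model, task=task)
analyzer.fit_all_factors(
    factors_name="ekfac_factors",   # unique identifier used to organize results
    dataset=train_dataset,          # dataset used to fit covariance and Lambda matrices
    overwrite_output_dir=False,
)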
""" self.fit_covariance_matrices( factors_name=factors_name, @@ -200,36 +197,41 @@ def fit_all_factors( @staticmethod def load_file(path: Union[str, Path]) -> Dict[str, torch.Tensor]: - """Loads the `.safetensors` file at the given path from disk. - - See https://github.com/huggingface/safetensors. + """Loads a `safetensors` file from the given path. Args: path (Path): - The path to the `.safetensors` file. + The path to the `safetensors` file. Returns: Dict[str, torch.Tensor]: - The contents of the file, which is the dictionary mapping string to tensors. + Dictionary mapping strings to tensors from the loaded file. + + Raises: + FileNotFoundError: + If the specified file does not exist. + + Note: + For more information on safetensors, see https://github.com/huggingface/safetensors. """ if isinstance(path, str): path = Path(path).resolve() if not path.exists(): - raise FileNotFoundError(f"File does not exists at `{path}`.") + raise FileNotFoundError(f"File not found: {path}.") return load_file(path) @staticmethod def get_module_summary(model: nn.Module) -> str: - """Returns the formatted summary of the modules in model. Useful identifying which modules to - compute influence scores. + """Generates a formatted summary of the model's modules, excluding those without parameters. This summary is + useful for identifying which modules to compute influence scores for. Args: model (nn.Module): - The PyTorch model to be investigated. + The PyTorch model to be summarized. Returns: str: - The formatted string summary of the model. + A formatted string containing the model summary, including module names and representations. """ format_str = "==Model Summary==" for module_name, module in model.named_modules(): diff --git a/kronfluence/arguments.py b/kronfluence/arguments.py index ef64f5d..8625993 100644 --- a/kronfluence/arguments.py +++ b/kronfluence/arguments.py @@ -10,7 +10,12 @@ class Arguments: """Base class for specifying arguments for computing factors and influence scores.""" def to_dict(self) -> Dict[str, Any]: - """Converts the arguments to a dictionary.""" + """Converts the arguments to a dictionary. + + Returns: + Dict[str, Any]: + A dictionary representation of the arguments, with `torch.dtype` values converted to strings. + """ config = copy.deepcopy(self.__dict__) for key, value in config.items(): if isinstance(value, torch.dtype): @@ -18,7 +23,12 @@ def to_dict(self) -> Dict[str, Any]: return config def to_str_dict(self) -> Dict[str, str]: - """Converts the arguments to a dictionary, where all values are converted to strings.""" + """Converts the arguments to a dictionary with all values as strings. + + Returns: + Dict[str, str]: + A dictionary representation of the arguments, with all values converted to strings. + """ config = copy.deepcopy(self.__dict__) for key, value in config.items(): config[key] = str(value) @@ -29,25 +39,18 @@ def to_str_dict(self) -> Dict[str, str]: class FactorArguments(Arguments): """Arguments for computing influence factors.""" - # General configuration. # + # General configuration # strategy: str = field( default="ekfac", - metadata={ - "help": "Specifies the algorithm for computing influence factors. Default is 'ekfac' " - "(Eigenvalue-corrected Kronecker-factored Approximate Curvature)." - }, + metadata={"help": "Specifies the algorithm for computing influence factors. 
Default is 'ekfac'."}, ) use_empirical_fisher: bool = field( default=False, metadata={ - "help": "Determines whether to approximate empirical Fisher (using true labels) or " - "true Fisher (using sampled labels)." + "help": "If `True`, approximates empirical Fisher (using true labels) instead of " + "true Fisher (using sampled labels from model's outputs)." }, ) - distributed_sync_interval: int = field( - default=1_000, - metadata={"help": "Number of iterations between synchronization steps in distributed computing settings."}, - ) amp_dtype: Optional[torch.dtype] = field( default=None, metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."}, @@ -60,10 +63,7 @@ class FactorArguments(Arguments): # Configuration for fitting covariance matrices. # covariance_max_examples: Optional[int] = field( default=100_000, - metadata={ - "help": "Maximum number of examples to use when fitting covariance matrices. " - "Uses entire dataset if `None`." - }, + metadata={"help": "Maximum number of examples for fitting covariance matrices. Uses entire dataset if `None`."}, ) covariance_data_partitions: int = field( default=1, @@ -78,28 +78,26 @@ class FactorArguments(Arguments): ) activation_covariance_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Data type for activation covariance computations."}, + metadata={"help": "Data type for activation covariance computation."}, ) gradient_covariance_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Data type for pseudo-gradient covariance computations."}, + metadata={"help": "Data type for pseudo-gradient covariance computation."}, ) - # Configuration for performing eigendecomposition. # + # Configuration for performing eigendecomposition # eigendecomposition_dtype: torch.dtype = field( default=torch.float64, metadata={ - "help": "Data type for eigendecomposition computations. Double precision (`torch.float64`) is " - "recommended for numerical stability." + "help": "Data type for eigendecomposition. Double precision (`torch.float64`) is recommended " + "for numerical stability." }, ) - # Configuration for fitting Lambda matrices. # + # Configuration for fitting Lambda matrices # lambda_max_examples: Optional[int] = field( default=100_000, - metadata={ - "help": "Maximum number of examples to use when fitting Lambda matrices. Uses entire dataset if `None`." - }, + metadata={"help": "Maximum number of examples for fitting Lambda matrices. Uses entire dataset if `None`."}, ) lambda_data_partitions: int = field( default=1, @@ -120,60 +118,68 @@ class FactorArguments(Arguments): ) offload_activations_to_cpu: bool = field( default=False, - metadata={ - "help": "If `True`, offloads cached activations to CPU memory when computing " - "per-sample gradients, reducing GPU memory usage." 
- }, + metadata={"help": "If `True`, offloads cached activations to CPU memory when computing per-sample gradients."}, ) per_sample_gradient_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Data type for per-sample gradient computations."}, + metadata={"help": "Data type for per-sample gradient computation."}, ) lambda_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Data type for Lambda matrix computations."}, + metadata={"help": "Data type for Lambda matrix computation."}, ) + def __post_init__(self) -> None: + if self.covariance_max_examples is not None and self.covariance_max_examples <= 0: + raise ValueError("`covariance_max_examples` must be `None` or positive.") + + if self.lambda_max_examples is not None and self.lambda_max_examples <= 0: + raise ValueError("`lambda_max_examples` must be `None` or positive.") + + if any( + partition <= 0 + for partition in [ + self.covariance_data_partitions, + self.covariance_module_partitions, + self.lambda_data_partitions, + self.lambda_module_partitions, + ] + ): + raise ValueError("All data and module partitions must be positive.") + @dataclass class ScoreArguments(Arguments): """Arguments for computing influence scores.""" - # General configuration. # + # General configuration # damping_factor: Optional[float] = field( default=1e-08, metadata={ - "help": "Damping factor for the inverse Hessian-vector product (iHVP). " + "help": "Damping factor for inverse Hessian-vector product (iHVP). " "If `None`, uses a heuristic of 0.1 times the mean eigenvalue." }, ) - distributed_sync_interval: int = field( - default=1_000, - metadata={"help": "Number of iterations between synchronization steps in distributed computing settings."}, - ) amp_dtype: Optional[torch.dtype] = field( default=None, metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."}, ) offload_activations_to_cpu: bool = field( default=False, - metadata={ - "help": "If `True`, offloads cached activations to CPU memory when computing " - "per-sample gradients, reducing GPU memory usage." - }, + metadata={"help": "If `True`, offloads cached activations to CPU memory when computing per-sample gradients."}, ) einsum_minimize_size: bool = field( - default=False, + default=True, metadata={ "help": "If `True`, einsum operations find the contraction that minimizes the size of the " "largest intermediate tensor." }, ) - # Partition configuration. # + # Partition configuration # data_partitions: int = field( default=1, - metadata={"help": "Number of partitions to divide the dataset into for influence score computation."}, + metadata={"help": "Number of partitions to divide the dataset for influence score computation."}, ) module_partitions: int = field( default=1, @@ -182,20 +188,20 @@ class ScoreArguments(Arguments): }, ) - # General score configuration. # + # General score configuration # compute_per_module_scores: bool = field( default=False, - metadata={"help": "If `True`, computes separate scores for each module instead of summing across all modules."}, + metadata={"help": "If `True`, computes separate scores for each module instead of summing across all."}, ) compute_per_token_scores: bool = field( default=False, metadata={ - "help": "If `True`, computes separate scores for each token instead of summing across all tokens. " + "help": "If `True`, computes separate scores for each token instead of summing across all. " "Only applicable to transformer-based models." }, ) - # Pairwise influence score configuration. 
# + # Pairwise influence score configuration # query_gradient_accumulation_steps: int = field( default=1, metadata={"help": "Number of query batches to accumulate before processing training examples."}, @@ -224,31 +230,44 @@ class ScoreArguments(Arguments): aggregate_train_gradients: bool = field( default=False, metadata={ - "help": "If `True`, uses the summed train gradient instead of per-sample train gradients " + "help": "If `True`, uses the summed training gradient instead of per-sample training gradients " "for pairwise influence computation." }, ) - # Self-influence score configuration. # + # Self-influence score configuration # use_measurement_for_self_influence: bool = field( default=False, metadata={"help": "If `True`, uses the measurement (instead of the loss) for computing self-influence scores."}, ) - # Precision configuration. # + # Precision configuration # query_gradient_svd_dtype: torch.dtype = field( default=torch.float32, metadata={"help": "Data type for singular value decomposition (SVD) of query gradient."}, ) - score_dtype: torch.dtype = field( - default=torch.float32, - metadata={"help": "Data type for computing and storing influence scores."}, - ) per_sample_gradient_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Data type for computing per-sample gradients."}, + metadata={"help": "Data type for per-sample gradient computation."}, ) precondition_dtype: torch.dtype = field( default=torch.float32, - metadata={"help": "Data type for computing the preconditioned gradient."}, + metadata={"help": "Data type for preconditioned gradient computation."}, ) + score_dtype: torch.dtype = field( + default=torch.float32, + metadata={"help": "Data type for influence score computation."}, + ) + + def __post_init__(self) -> None: + if self.damping_factor is not None and self.damping_factor <= 0: + raise ValueError("`damping_factor` must be None or positive.") + + if any(partition <= 0 for partition in [self.data_partitions, self.module_partitions]): + raise ValueError("Both data and module partitions must be positive.") + + if self.query_gradient_accumulation_steps <= 0: + raise ValueError("`query_gradient_accumulation_steps` must be positive.") + + if self.query_gradient_low_rank is not None and self.query_gradient_low_rank <= 0: + raise ValueError("`query_gradient_low_rank` must be None or positive.") diff --git a/kronfluence/computer/computer.py b/kronfluence/computer/computer.py index 9abd0d3..62752af 100644 --- a/kronfluence/computer/computer.py +++ b/kronfluence/computer/computer.py @@ -70,10 +70,10 @@ def __init__( profile: bool = False, disable_tqdm: bool = False, ) -> None: - """Initializes an instance of the Computer class.""" + """Initializes an instance of the `Computer` class. See `Analyzer` for more information.""" self.state = State(cpu=cpu) - # Creates and configures logger. + # Create and configure logger. disable_log = log_main_process_only and self.state.process_index != 0 self.logger = get_logger(name=__name__, log_level=log_level, disable_log=disable_log) @@ -84,19 +84,17 @@ def __init__( if len(tracked_module_names) == 0: error_msg = ( f"No tracked modules found in the provided model: {self.model}. " - f"Please make sure to run `prepare_model` before passing it in to the " - f"Analyzer." + f"Please ensure you've run `prepare_model` before passing it to the Analyzer." 
) self.logger.error(error_msg) raise TrackedModuleNotFoundError(error_msg) self.logger.info(f"Tracking modules with names: {tracked_module_names}.") if self.state.use_distributed and not isinstance(model, (DDP, FSDP)): - warning_msg = ( - "Creating a DDP module. If specific configuration needs to be used " - "for DDP, please pass in the model after the manual DDP wrapping." + self.logger.warning( + "Creating a DDP module. For custom DDP configuration, " + "please manually wrap the model with DDP before passing it in." ) - self.logger.warning(warning_msg) self.model.to(self.state.device) self.model = DDP( self.model, @@ -105,22 +103,25 @@ def __init__( ) if cpu and isinstance(model, (DataParallel, DDP, FSDP)): - error_msg = "To enforce CPU, the model should not be wrapped with DP, DDP, or FSDP." + error_msg = ( + "CPU enforcement is incompatible with DP, DDP, or FSDP wrapped models. " + "Please provide an unwrapped model when using `cpu=True`." + ) self.logger.error(error_msg) raise ValueError(error_msg) if not self.state.use_distributed: self.model.to(self.state.device) - # Creates and configures output directory. + # Create and configure output directory. self.output_dir = Path(output_dir).joinpath(name).resolve() os.makedirs(name=self.output_dir, exist_ok=True) - # Creates and configures profiler. + # Create and configure profiler. self.profiler = Profiler(state=self.state) if profile else PassThroughProfiler(state=self.state) self.disable_tqdm = disable_tqdm - # Sets PyTorch DataLoader arguments. + # Set PyTorch DataLoader arguments. self._dataloader_params = DataLoaderKwargs() def factors_output_dir(self, factors_name: str) -> Path: @@ -145,11 +146,11 @@ def _save_arguments( loaded_arguments = load_json(arguments_save_path) if loaded_arguments != arguments.to_dict(): error_msg = ( - "Attempting to use the arguments that differs from the one already saved. " - "Please set `overwrite_output_dir=True` to overwrite existing experiment." + f"New arguments differ from saved arguments at `{arguments_save_path}`. " + "Set `overwrite_output_dir=True` to overwrite existing experiment.\n" + f"New arguments: {arguments.to_dict()}\n" + f"Saved arguments: {loaded_arguments}" ) - error_msg += f"\nNew arguments: {arguments.to_dict()}." - error_msg += f"\nSaved arguments: {loaded_arguments}." self.logger.error(error_msg) raise ValueError(error_msg) else: @@ -198,13 +199,13 @@ def _get_dataloader( allow_duplicates: bool = False, stack: bool = False, ) -> data.DataLoader: - """Returns the DataLoader for the given dataset, per_device_batch_size, and additional parameters.""" + """Returns the DataLoader with the provided configuration.""" if indices is not None: dataset = data.Subset(dataset=dataset, indices=indices) if self.state.use_distributed and not allow_duplicates: if stack: - error_msg = "DistributedEvalSampler is not currently supported with `stack=True`." + error_msg = "DistributedEvalSampler is incompatible with `stack=True`." self.logger.error(error_msg) raise ValueError(error_msg) sampler = DistributedEvalSampler( @@ -255,9 +256,8 @@ def _get_data_partition( """Partitions the dataset into several chunks.""" if total_data_examples < data_partitions: error_msg = ( - f"Data partition size ({data_partitions}) cannot be greater than the " - f"total data points ({total_data_examples}). Please reduce the data partition " - f"size in the argument." + f"Data partition size ({data_partitions}) exceeds total data points ({total_data_examples}). " + "Please reduce the data partition size." 
) self.logger.error(error_msg) raise ValueError(error_msg) @@ -274,10 +274,7 @@ def _get_data_partition( for data_partition in target_data_partitions: if data_partition < 0 or data_partition > data_partitions: - error_msg = ( - f"Invalid data partition {data_partition} encountered. " - f"The module partition needs to be in between [0, {data_partitions})." - ) + error_msg = f"Invalid data partition {data_partition}. Must be in range [0, {data_partitions})." self.logger.error(error_msg) raise ValueError(error_msg) @@ -293,9 +290,8 @@ def _get_module_partition( if len(tracked_module_names) < module_partitions: error_msg = ( - f"Module partition size ({module_partitions}) cannot be greater than the " - f"total tracked modules ({len(tracked_module_names)}). Please reduce the module partition " - f"size in the argument." + f"Module partition size ({module_partitions}) exceeds total tracked modules " + f"({len(tracked_module_names)}). Please reduce the module partition size." ) self.logger.error(error_msg) raise ValueError(error_msg) @@ -313,10 +309,7 @@ def _get_module_partition( for module_partition in target_module_partitions: if module_partition < 0 or module_partition > module_partitions: - error_msg = ( - f"Invalid module partition {module_partition} encountered. " - f"The module partition needs to be in between [0, {module_partitions})." - ) + error_msg = f"Invalid module partition {module_partition}. Must be in range [0, {module_partitions})." self.logger.error(error_msg) raise ValueError(error_msg) @@ -325,7 +318,7 @@ def _get_module_partition( def _reset_memory(self) -> None: """Clears all cached memory.""" self.model.zero_grad(set_to_none=True) - set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) + set_mode(model=self.model, mode=ModuleMode.DEFAULT, release_memory=True) release_memory() def _log_profile_summary(self, name: str) -> None: @@ -402,7 +395,7 @@ def load_all_factors(self, factors_name: str) -> FACTOR_TYPE: if factor_args is None: error_msg = f"Factors with name `{factors_name}` was not found at `{factors_output_dir}`." self.logger.error(error_msg) - raise ValueError(error_msg) + raise FileNotFoundError(error_msg) loaded_factors: FACTOR_TYPE = {} factor_config = FactorConfig.CONFIGS[factor_args.strategy] diff --git a/kronfluence/computer/factor_computer.py b/kronfluence/computer/factor_computer.py index af67eba..22e5ff0 100644 --- a/kronfluence/computer/factor_computer.py +++ b/kronfluence/computer/factor_computer.py @@ -39,7 +39,7 @@ class FactorComputer(Computer): def _configure_and_save_factor_args( self, factor_args: Optional[FactorArguments], factors_output_dir: Path, overwrite_output_dir: bool ) -> FactorArguments: - """Configures the provided factor arguments and saves it in disk.""" + """Configures and saves factor arguments to disk.""" if factor_args is None: factor_args = FactorArguments() self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.") @@ -69,7 +69,7 @@ def _aggregate_factors( """Aggregates factors computed for all data and module partitions.""" factors_output_dir = self.factors_output_dir(factors_name=factors_name) if not factors_output_dir.exists(): - error_msg = f"Factors directory `{factors_output_dir}` is not found when trying to aggregate factors." + error_msg = f"Factors directory `{factors_output_dir}` not found when trying to aggregate factors." 
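# --- Illustrative sketch (not part of this patch) ---
# With `load_all_factors` now raising `FileNotFoundError` (see the hunk above), a caller can
# fall back to fitting the factors; `analyzer`, `train_dataset`, and the factor name are hypothetical.
try:
    factors = analyzer.load_all_factors(factors_name="ekfac_factors")
except FileNotFoundError:
    analyzer.fit_all_factors(factors_name="ekfac_factors", dataset=train_dataset)
    factors = analyzer.load_all_factors(factors_name="ekfac_factors")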
self.logger.error(error_msg) raise FileNotFoundError(error_msg) @@ -78,7 +78,7 @@ def _aggregate_factors( exist_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions ) if not all_partition_exists: - self.logger.warning("Factors are not aggregated as factors for some partitions are not yet computed.") + self.logger.info("Factors are not aggregated as factors for some partitions are not yet computed.") return start_time = time.time() @@ -95,8 +95,9 @@ def _aggregate_factors( for module_name in factors: if module_name not in aggregated_factors[factor_name]: - aggregated_factors[factor_name][module_name] = torch.zeros( - size=factors[module_name].shape, dtype=factors[module_name].dtype, requires_grad=False + aggregated_factors[factor_name][module_name] = torch.zeros_like( + factors[module_name], + requires_grad=False, ) aggregated_factors[factor_name][module_name].add_(factors[module_name]) del loaded_factors @@ -120,7 +121,7 @@ def _find_executable_factors_batch_size( """Automatically finds executable batch size for performing `func`.""" if self.state.use_distributed: error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " + "Automatic batch size search is not supported for multi-GPU setting. " "Please manually configure the batch size by passing in `per_device_batch_size`." ) self.logger.error(error_msg) @@ -175,7 +176,7 @@ def fit_covariance_matrices( factors_name (str): The unique identifier for the factor, used to organize and retrieve the results. dataset (data.Dataset): - The dataset that will be used to fit covariance matrices. + The dataset that will be used for fitting covariance matrices. per_device_batch_size (int, optional): The per-device batch size used to fit the factors. If not specified, executable batch size is automatically determined. @@ -184,8 +185,7 @@ def fit_covariance_matrices( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, the default values of - `FactorArguments` will be used. + Arguments for factor computation. target_data_partitions(Sequence[int], int, optional): The list of data partition to fit covariance matrices. By default, covariance matrices will be computed for all partitions. @@ -193,7 +193,7 @@ def fit_covariance_matrices( The list of module partition to fit covariance matrices. By default, covariance matrices will be computed for all partitions. overwrite_output_dir (bool, optional): - If True, the existing factors with the same `factors_name` will be overwritten. + Whether to overwrite existing output. """ self.logger.debug(f"Fitting covariance matrices with parameters: {locals()}") @@ -250,10 +250,7 @@ def fit_covariance_matrices( ) if max_partition_examples < self.state.num_processes: - error_msg = ( - "The number of processes are larger than the total data examples. " - "Try reducing the number of processes." - ) + error_msg = "The number of processes are larger than total data examples. Try reducing number of processes." self.logger.error(error_msg) raise ValueError(error_msg) @@ -366,7 +363,7 @@ def aggregate_covariance_matrices( if factor_args is None: error_msg = ( f"Arguments for factors with name `{factors_name}` was not found when trying to " - f"aggregated covariance matrices." + f"aggregate covariance matrices." 
) self.logger.error(error_msg) raise ValueError(error_msg) @@ -388,26 +385,25 @@ def perform_eigendecomposition( overwrite_output_dir: bool = False, load_from_factors_name: Optional[str] = None, ) -> None: - """Performs Eigendecomposition on all covariance matrices. + """Performs eigendecomposition on all covariance matrices. Args: factors_name (str): The unique identifier for the factor, used to organize and retrieve the results. factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, the default values of - `FactorArguments` will be used. + Arguments for factor computation. overwrite_output_dir (bool, optional): - If True, the existing factors with the same `factors_name` will be overwritten. + Whether to overwrite existing output. load_from_factors_name (str, optional): The `factor_name` to load covariance matrices from. By default, covariance matrices with the same `factor_name` will be used. """ - self.logger.debug(f"Performing Eigendecomposition with parameters: {locals()}") + self.logger.debug(f"Performing eigendecomposition with parameters: {locals()}") factors_output_dir = self.factors_output_dir(factors_name=factors_name) os.makedirs(factors_output_dir, exist_ok=True) if eigendecomposition_exist(output_dir=factors_output_dir) and not overwrite_output_dir: - self.logger.info(f"Found existing Eigendecomposition results at `{factors_output_dir}`. Skipping.") + self.logger.info(f"Found existing eigendecomposition results at `{factors_output_dir}`. Skipping.") return factor_args = self._configure_and_save_factor_args( @@ -416,7 +412,7 @@ def perform_eigendecomposition( if not FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition: self.logger.info( - f"Strategy `{factor_args.strategy}` does not require performing Eigendecomposition. Skipping." + f"Strategy `{factor_args.strategy}` does not require performing eigendecomposition. Skipping." ) return None @@ -428,7 +424,7 @@ def perform_eigendecomposition( if not covariance_matrices_exist(output_dir=load_factors_output_dir): error_msg = ( f"Covariance matrices not found at `{load_factors_output_dir}`. " - f"To perform Eigendecomposition, covariance matrices need to be first computed." + f"To perform eigendecomposition, covariance matrices need to be first computed." ) self.logger.error(error_msg) raise FactorsNotFoundError(error_msg) @@ -463,13 +459,13 @@ def perform_eigendecomposition( ) end_time = time.time() elapsed_time = end_time - start_time - self.logger.info(f"Performed Eigendecomposition in {elapsed_time:.2f} seconds.") + self.logger.info(f"Performed eigendecomposition in {elapsed_time:.2f} seconds.") with self.profiler.profile("Save Eigendecomposition"): save_eigendecomposition( output_dir=factors_output_dir, factors=eigen_factors, metadata=factor_args.to_str_dict() ) - self.logger.info(f"Saved Eigendecomposition results at `{factors_output_dir}`.") + self.logger.info(f"Saved eigendecomposition results at `{factors_output_dir}`.") self.state.wait_for_everyone() self._log_profile_summary(name=f"factors_{factors_name}_eigendecomposition") @@ -492,7 +488,7 @@ def fit_lambda_matrices( factors_name (str): The unique identifier for the factor, used to organize and retrieve the results. dataset (data.Dataset): - The dataset that will be used to fit Lambda matrices. + The dataset that will be used for fitting Lambda matrices. per_device_batch_size (int, optional): The per-device batch size used to fit the factors. 
If not specified, executable batch size is automatically determined. @@ -501,8 +497,7 @@ def fit_lambda_matrices( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. factor_args (FactorArguments, optional): - Arguments related to computing the factors. If not specified, the default values of - `FactorArguments` will be used. + Arguments for factor computation. target_data_partitions(Sequence[int], int, optional): The list of data partition to fit Lambda matrices. By default, Lambda matrices will be computed for all partitions. @@ -510,9 +505,9 @@ def fit_lambda_matrices( The list of module partition to fit Lambda matrices. By default, Lambda matrices will be computed for all partitions. overwrite_output_dir (bool, optional): - If True, the existing factors with the same `factors_name` will be overwritten. + Whether to overwrite existing output. load_from_factors_name (str, optional): - The `factor_name` to load Eigendecomposition results from. By default, Eigendecomposition + The `factor_name` to load eigendecomposition results from. By default, eigendecomposition results with the same `factor_name` will be used. """ self.logger.debug(f"Fitting Lambda matrices with parameters: {locals()}") @@ -542,7 +537,7 @@ def fit_lambda_matrices( if load_from_factors_name is not None: self.logger.info( - f"Will be loading Eigendecomposition results from factors with name `{load_from_factors_name}`." + f"Will be loading eigendecomposition results from factors with name `{load_from_factors_name}`." ) load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) else: @@ -554,7 +549,7 @@ def fit_lambda_matrices( ): error_msg = ( f"Eigendecomposition results not found at `{load_factors_output_dir}`. " - f"To fit Lambda matrices for `{factor_args.strategy}`, Eigendecomposition must be " + f"To fit Lambda matrices for `{factor_args.strategy}`, eigendecomposition must be " f"performed before computing Lambda matrices." ) self.logger.error(error_msg) @@ -604,10 +599,7 @@ def fit_lambda_matrices( ) if max_partition_examples < self.state.num_processes: - error_msg = ( - "The number of processes are larger than the total data examples. " - "Try reducing the number of processes." - ) + error_msg = "The number of processes are larger than total data examples. Try reducing number of processes." self.logger.error(error_msg) raise ValueError(error_msg) @@ -722,7 +714,7 @@ def aggregate_lambda_matrices( if factor_args is None: error_msg = ( f"Arguments for factors with name `{factors_name}` was not found when trying " - f"to aggregated Lambda matrices." + f"to aggregate Lambda matrices." ) self.logger.error(error_msg) raise ValueError(error_msg) diff --git a/kronfluence/computer/score_computer.py b/kronfluence/computer/score_computer.py index 018f2e9..d2568d6 100644 --- a/kronfluence/computer/score_computer.py +++ b/kronfluence/computer/score_computer.py @@ -39,7 +39,7 @@ def _configure_and_save_score_args( factors_name: str, overwrite_output_dir: bool, ) -> Tuple[FactorArguments, ScoreArguments]: - """Configures the provided score arguments and saves it in disk.""" + """Configures and saves score arguments to disk.""" if score_args is None: score_args = ScoreArguments() self.logger.info(f"Score arguments not provided. 
Using the default configuration: {score_args}.") @@ -49,7 +49,7 @@ def _configure_and_save_score_args( factor_args = self.load_factor_args(factors_name=factors_name) factors_output_dir = self.factors_output_dir(factors_name=factors_name) if factor_args is None: - error_msg = f"Factors with name `{factors_name}` was not found at `{factors_output_dir}`." + error_msg = f"Factors with name `{factors_name}` not found at `{factors_output_dir}`." self.logger.error(error_msg) raise FactorsNotFoundError(error_msg) self.logger.info(f"Loaded `FactorArguments` with configuration: {factor_args}.") @@ -83,10 +83,7 @@ def _aggregate_scores( """Aggregates influence scores computed for all data and module partitions.""" scores_output_dir = self.scores_output_dir(scores_name=scores_name) if not scores_output_dir.exists(): - error_msg = ( - f"Scores output directory `{scores_output_dir}` is not found " - f"when trying to aggregate partitioned scores." - ) + error_msg = f"Scores directory `{scores_output_dir}` not found when trying to aggregate scores." self.logger.error(error_msg) raise FileNotFoundError(error_msg) @@ -97,31 +94,32 @@ def _aggregate_scores( exist_fnc(output_dir=scores_output_dir, partition=partition) for partition in all_required_partitions ) if not all_partition_exists: - self.logger.info("Influence scores are not aggregated as scores for some partitions are not yet computed.") + self.logger.info("Scores are not aggregated as scores for some partitions are not yet computed.") return start_time = time.time() aggregated_scores: SCORE_TYPE = {} - with self.profiler.profile("Aggregate Score"): - for data_partition in range(score_args.data_partitions): - aggregated_module_scores = {} + for data_partition in range(score_args.data_partitions): + aggregated_module_scores = {} - for module_partition in range(score_args.module_partitions): - loaded_scores = load_fnc( - output_dir=scores_output_dir, - partition=(data_partition, module_partition), - ) + for module_partition in range(score_args.module_partitions): + loaded_scores = load_fnc( + output_dir=scores_output_dir, + partition=(data_partition, module_partition), + ) - for module_name, scores in loaded_scores.items(): - if module_name not in aggregated_module_scores: - aggregated_module_scores[module_name] = scores - else: - aggregated_module_scores[module_name].add_(scores) - del loaded_scores + for module_name, scores in loaded_scores.items(): + if module_name not in aggregated_module_scores: + aggregated_module_scores[module_name] = torch.zeros_like(scores, requires_grad=False) + aggregated_module_scores[module_name].add_(scores) + del loaded_scores - for module_name, scores in aggregated_module_scores.items(): - if module_name not in aggregated_scores: - aggregated_scores[module_name] = scores + for module_name, scores in aggregated_module_scores.items(): + if module_name not in aggregated_scores: + aggregated_scores[module_name] = scores.clone() + else: + if score_args.aggregate_train_gradients: + aggregated_scores[module_name].add_(scores) else: aggregated_scores[module_name] = torch.cat( ( @@ -130,10 +128,10 @@ def _aggregate_scores( ), dim=dim, ) - save_fnc(output_dir=scores_output_dir, scores=aggregated_scores, metadata=score_args.to_str_dict()) + save_fnc(output_dir=scores_output_dir, scores=aggregated_scores, metadata=score_args.to_str_dict()) end_time = time.time() elapsed_time = end_time - start_time - self.logger.info(f"Aggregated all partitioned scores in {elapsed_time:.2f} seconds.") + self.logger.info(f"Aggregated all scores in 
{elapsed_time:.2f} seconds.") return aggregated_scores def _find_executable_pairwise_scores_batch_size( @@ -152,7 +150,7 @@ def _find_executable_pairwise_scores_batch_size( """Automatically finds executable training batch size for computing pairwise influence scores.""" if self.state.use_distributed: error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " + "Automatic batch size search is not supported for multi-GPU setting. " "Please manually configure the batch size by passing in `per_device_batch_size`." ) self.logger.error(error_msg) @@ -187,7 +185,12 @@ def executable_batch_size_func(batch_size: int) -> None: allow_duplicates=True, stack=True, ) - compute_pairwise_scores_with_loaders( + func = ( + compute_pairwise_scores_with_loaders + if not score_args.aggregate_query_gradients + else compute_pairwise_query_aggregated_scores_with_loaders + ) + func( model=self.model, state=self.state, task=self.task, @@ -225,9 +228,7 @@ def compute_pairwise_scores( target_module_partitions: Optional[Sequence[int]] = None, overwrite_output_dir: bool = False, ) -> Optional[SCORE_TYPE]: - """Computes pairwise influence scores for the given score configuration. As an example, - for Q query dataset and T training dataset, the pairwise influence scores are - 2-dimensional matrix with dimension `Q x T`. + """Computes pairwise influence scores with the given score configuration. Args: scores_name (str): @@ -254,8 +255,7 @@ def compute_pairwise_scores( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. score_args (ScoreArguments, optional): - Arguments related to computing the pairwise scores. If not specified, the default values - of `ScoreArguments` will be used. + Arguments for score computation. target_data_partitions (Sequence[int], optional): Specific data partitions to compute influence scores. If not specified, scores for all data partitions will be computed. @@ -263,7 +263,7 @@ def compute_pairwise_scores( Specific module partitions to compute influence scores. If not specified, scores for all module partitions will be computed. overwrite_output_dir (bool, optional): - If True, the existing factors with the same name will be overwritten. + Whether to overwrite existing output. """ self.logger.debug(f"Computing pairwise scores with parameters: {locals()}") @@ -282,7 +282,7 @@ def compute_pairwise_scores( if score_args.compute_per_token_scores and score_args.aggregate_train_gradients: warning_msg = ( - "Token-wise influence computation is not compatible with `aggregate_train_gradients`. " + "Token-wise influence computation is not compatible with `aggregate_train_gradients=True`. " "Disabling `compute_per_token_scores`." ) score_args.compute_per_token_scores = False @@ -290,16 +290,16 @@ def compute_pairwise_scores( if score_args.compute_per_token_scores and factor_args.has_shared_parameters: warning_msg = ( - "Token-wise influence computation is not compatible with `has_shared_parameters`. " + "Token-wise influence computation is not compatible with `has_shared_parameters=True`. " "Disabling `compute_per_token_scores`." 
) score_args.compute_per_token_scores = False self.logger.warning(warning_msg) - if score_args.compute_per_token_scores and self.task.do_post_process_per_sample_gradient: + if score_args.compute_per_token_scores and self.task.enable_post_process_per_sample_gradient: warning_msg = ( "Token-wise influence computation is not compatible with tasks that requires " - "`post_process_per_sample_gradient`. Disabling `compute_per_token_scores`." + "`enable_post_process_per_sample_gradient`. Disabling `compute_per_token_scores`." ) score_args.compute_per_token_scores = False self.logger.warning(warning_msg) @@ -377,7 +377,7 @@ def compute_pairwise_scores( start_index, end_index = data_partition_indices[data_partition] self.logger.info( - f"Fitting pairwise scores with data indices ({start_index}, {end_index}) and " + f"Computing pairwise scores with data indices ({start_index}, {end_index}) and " f"modules {module_partition_names[module_partition]}." ) @@ -471,19 +471,20 @@ def aggregate_pairwise_scores(self, scores_name: str) -> None: if score_args is None: error_msg = ( f"Arguments for scores with name `{score_args}` was not found when trying " - f"to aggregated pairwise influence scores." + f"to aggregate pairwise influence scores." ) self.logger.error(error_msg) raise ValueError(error_msg) - self._aggregate_scores( - scores_name=scores_name, - score_args=score_args, - exist_fnc=pairwise_scores_exist, - load_fnc=load_pairwise_scores, - save_fnc=save_pairwise_scores, - dim=1, - ) + with self.profiler.profile("Aggregate Score"): + self._aggregate_scores( + scores_name=scores_name, + score_args=score_args, + exist_fnc=pairwise_scores_exist, + load_fnc=load_pairwise_scores, + save_fnc=save_pairwise_scores, + dim=1, + ) def _find_executable_self_scores_batch_size( self, @@ -499,8 +500,8 @@ def _find_executable_self_scores_batch_size( """Automatically finds executable training batch size for computing self-influence scores.""" if self.state.use_distributed: error_msg = ( - "Automatic batch size search is currently not supported for multi-GPU training. " - "Please manually configure the batch size by passing in `per_device_train_batch_size`." + "Automatic batch size search is not supported for multi-GPU setting. " + "Please manually configure the batch size by passing in `per_device_batch_size`." ) self.logger.error(error_msg) raise NotImplementedError(error_msg) @@ -563,8 +564,7 @@ def compute_self_scores( target_module_partitions: Optional[Sequence[int]] = None, overwrite_output_dir: bool = False, ) -> Optional[SCORE_TYPE]: - """Computes self-influence scores for the given score configuration. As an example, - for training dataset with T examples, the self-influence scores are represented as T-dimensional vector. + """Computes self-influence scores with the given score configuration. Args: scores_name (str): @@ -584,8 +584,7 @@ def compute_self_scores( dataloader_kwargs (DataLoaderKwargs, optional): Controls additional arguments for PyTorch's DataLoader. score_args (ScoreArguments, optional): - Arguments related to computing the self-influence scores. If not specified, the default values - of `ScoreArguments` will be used. + Arguments for score computation. target_data_partitions (Sequence[int], optional): Specific data partitions to compute influence scores. If not specified, scores for all data partitions will be computed. @@ -593,7 +592,7 @@ def compute_self_scores( Specific module partitions to compute influence scores. If not specified, scores for all module partitions will be computed. 
overwrite_output_dir (bool, optional): - If True, the existing factors with the same name will be overwritten. + Whether to overwrite existing output. """ self.logger.debug(f"Computing self-influence scores with parameters: {locals()}") @@ -612,6 +611,7 @@ def compute_self_scores( if score_args.query_gradient_accumulation_steps != 1: warning_msg = "Query gradient accumulation is not supported for self-influence computation." + score_args.query_gradient_accumulation_steps = 1 self.logger.warning(warning_msg) if score_args.query_gradient_low_rank is not None: @@ -619,10 +619,18 @@ def compute_self_scores( "Low rank query gradient approximation is not supported for self-influence computation. " "No low rank query approximation will be performed." ) + score_args.query_gradient_low_rank = None self.logger.warning(warning_msg) if score_args.aggregate_query_gradients or score_args.aggregate_train_gradients: warning_msg = "Query or train gradient aggregation is not supported for self-influence computation." + score_args.aggregate_train_gradients = False + score_args.aggregate_query_gradients = False + self.logger.warning(warning_msg) + + if score_args.compute_per_token_scores: + warning_msg = "Token-wise influence computation is not compatible with self-influence scores. " + score_args.compute_per_token_scores = False self.logger.warning(warning_msg) dataloader_params = self._configure_dataloader(dataloader_kwargs) @@ -687,7 +695,7 @@ def compute_self_scores( start_index, end_index = data_partition_indices[data_partition] self.logger.info( - f"Fitting self-influence scores with data indices ({start_index}, {end_index}) and " + f"Computing self-influence scores with data indices ({start_index}, {end_index}) and " f"modules {module_partition_names[module_partition]}." ) @@ -768,11 +776,12 @@ def aggregate_self_scores(self, scores_name: str) -> None: if score_args is None: error_msg = ( f"Arguments for scores with name `{score_args}` was not found when trying " - f"to aggregated self-influence scores." + f"to aggregate self-influence scores." 
) self.logger.error(error_msg) raise ValueError(error_msg) + score_args.aggregate_query_gradients = score_args.aggregate_train_gradients = False self._aggregate_scores( scores_name=scores_name, score_args=score_args, diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py index 976a261..9863c02 100644 --- a/kronfluence/factor/config.py +++ b/kronfluence/factor/config.py @@ -40,32 +40,32 @@ def __init_subclass__(cls, factor_strategy: Optional[FactorStrategy] = None, **k @property @abstractmethod def requires_covariance_matrices(self) -> bool: - """Returns True if the strategy requires computing covariance matrices.""" + """Returns `True` if the strategy requires computing covariance matrices.""" raise NotImplementedError("Subclasses must implement the `requires_covariance_matrices` property.") @property @abstractmethod def requires_eigendecomposition(self) -> bool: - """Returns True if the strategy requires performing Eigendecomposition.""" + """Returns `True` if the strategy requires performing Eigendecomposition.""" raise NotImplementedError("Subclasses must implement the `requires_eigendecomposition` property.") @property @abstractmethod def requires_lambda_matrices(self) -> bool: - """Returns True if the strategy requires computing Lambda matrices.""" + """Returns `True` if the strategy requires computing Lambda matrices.""" raise NotImplementedError("Subclasses must implement the `requires_lambda_matrices` property.") @property @abstractmethod def requires_eigendecomposition_for_lambda(self) -> bool: - """Returns True if the strategy requires loading Eigendecomposition results, before computing + """Returns `True` if the strategy requires loading Eigendecomposition results, before computing Lambda matrices.""" raise NotImplementedError("Subclasses must implement the `requires_eigendecomposition_for_lambda` property.") @property @abstractmethod def requires_covariance_matrices_for_precondition(self) -> bool: - """Returns True if the strategy requires loading covariance matrices, before computing + """Returns `True` if the strategy requires loading covariance matrices, before computing preconditioned gradient.""" raise NotImplementedError( "Subclasses must implement the `requires_covariance_matrices_for_precondition` property." @@ -74,7 +74,7 @@ def requires_covariance_matrices_for_precondition(self) -> bool: @property @abstractmethod def requires_eigendecomposition_for_precondition(self) -> bool: - """Returns True if the strategy requires loading Eigendecomposition results, before computing + """Returns `True` if the strategy requires loading Eigendecomposition results, before computing preconditioned gradient.""" raise NotImplementedError( "Subclasses must implement the `requires_eigendecomposition_for_precondition` property." @@ -83,34 +83,42 @@ def requires_eigendecomposition_for_precondition(self) -> bool: @property @abstractmethod def requires_lambda_matrices_for_precondition(self) -> bool: - """Returns True if the strategy requires loading Lambda matrices, before computing + """Returns `True` if the strategy requires loading Lambda matrices, before computing the preconditioned gradient.""" raise NotImplementedError("Subclasses must implement the `requires_lambda_matrices_for_precondition` property.") + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + """Performs necessary operations before computing the preconditioned gradient. 
+ + Args: + storage (STORAGE_TYPE): + A dictionary containing various factors required to compute the preconditioned gradient. + See `.storage` in `TrackedModule` for details. + score_args (ScoreArguments): + Arguments for computing the preconditioned gradient. + device (torch.device): + Device used for computing the preconditioned gradient. + """ + @abstractmethod def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - """Preconditions the per-sample-gradient. The per-sample-gradient is a 3-dimensional - tensor with the shape `batch_size x output_dim x input_dim`. + """Preconditions the per-sample gradient. The per-sample gradient is a 3-dimensional + tensor with shape `batch_size x output_dim x input_dim`. Args: gradient (torch.Tensor): - The per-sample-gradient tensor. + The per-sample gradient tensor. storage (STORAGE_TYPE): A dictionary containing various factors required to compute the preconditioned gradient. See `.storage` in `TrackedModule` for details. - damping (float, optional): - The damping factor when computing the preconditioned gradient. If not provided, sets - the damping term with some heuristic. Returns: torch.Tensor: - The preconditioned per-sample-gradient tensor. The dimension should be the same as the original - per-sample-gradient. + The preconditioned per-sample gradient tensor. """ raise NotImplementedError("Subclasses must implement the `precondition_gradient` method.") @@ -150,9 +158,8 @@ def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - del storage, damping + del storage return gradient @@ -187,18 +194,22 @@ def requires_eigendecomposition_for_precondition(self) -> bool: def requires_lambda_matrices_for_precondition(self) -> bool: return True + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device) + lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED]) + damping_factor = score_args.damping_factor + if damping_factor is None: + damping_factor = 0.1 * torch.mean(lambda_matrix) + lambda_matrix.add_(damping_factor) + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu") + def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device) - num_lambda_processed = storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device) - lambda_matrix = lambda_matrix / num_lambda_processed - if damping is None: - damping = 0.1 * torch.mean(lambda_matrix) - return gradient / (lambda_matrix + damping) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) + return gradient / lambda_matrix class Kfac(FactorConfig, factor_strategy=FactorStrategy.KFAC): @@ -235,27 +246,37 @@ def requires_eigendecomposition_for_precondition(self) -> bool: def requires_lambda_matrices_for_precondition(self) -> bool: return False + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to( + dtype=score_args.precondition_dtype + ) + storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to( + dtype=score_args.precondition_dtype + ) + activation_eigenvalues = 
storage[ACTIVATION_EIGENVALUES_NAME].to(device=device) + gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device) + lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) + damping_factor = score_args.damping_factor + if damping_factor is None: + damping_factor = 0.1 * torch.mean(lambda_matrix) + lambda_matrix.add_(damping_factor) + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu") + storage[ACTIVATION_EIGENVALUES_NAME] = None + storage[GRADIENT_EIGENVALUES_NAME] = None + @torch.no_grad() def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=gradient.dtype, device=gradient.device) - gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=gradient.dtype, device=gradient.device) - # The eigenvalues have the Kronecker structure for KFAC. - lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) - + activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(device=gradient.device) + gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) - - if damping is None: - damping = 0.1 * torch.mean(lambda_matrix) - - gradient.div_(lambda_matrix + damping) - return torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) + gradient.div_(lambda_matrix) + gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) + return gradient class Ekfac(FactorConfig, factor_strategy=FactorStrategy.EKFAC): @@ -292,23 +313,34 @@ def requires_eigendecomposition_for_precondition(self) -> bool: def requires_lambda_matrices_for_precondition(self) -> bool: return True + def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: + storage[ACTIVATION_EIGENVECTORS_NAME] = storage[ACTIVATION_EIGENVECTORS_NAME].to( + dtype=score_args.precondition_dtype + ) + storage[GRADIENT_EIGENVECTORS_NAME] = storage[GRADIENT_EIGENVECTORS_NAME].to( + dtype=score_args.precondition_dtype + ) + storage[ACTIVATION_EIGENVALUES_NAME] = None + storage[GRADIENT_EIGENVALUES_NAME] = None + + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device) + lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED]) + damping_factor = score_args.damping_factor + if damping_factor is None: + damping_factor = 0.1 * torch.mean(lambda_matrix) + lambda_matrix.add_(damping_factor) + storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu") + @torch.no_grad() def precondition_gradient( self, gradient: torch.Tensor, storage: STORAGE_TYPE, - damping: Optional[float] = None, ) -> torch.Tensor: - activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) - lambda_matrix = 
storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device) - num_lambda_processed = storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device) - lambda_matrix = lambda_matrix / num_lambda_processed + activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(device=gradient.device) + gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) - - if damping is None: - damping = 0.1 * torch.mean(lambda_matrix) - - gradient.div_(lambda_matrix + damping) + gradient.div_(lambda_matrix) gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) return gradient diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index 6e2f2f1..6133db8 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -18,12 +18,13 @@ set_attention_mask, set_gradient_scale, set_mode, - synchronize_covariance_matrices, + synchronize_modules, update_factor_args, ) from kronfluence.task import Task from kronfluence.utils.constants import ( COVARIANCE_FACTOR_NAMES, + DISTRIBUTED_SYNC_INTERVAL, FACTOR_TYPE, PARTITION_TYPE, ) @@ -36,7 +37,24 @@ def covariance_matrices_save_path( factor_name: str, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading covariance matrices.""" + """Generates the path for saving or loading covariance matrices. + + Args: + output_dir (Path): + Directory to save the matrices. + factor_name (str): + Name of the factor (must be in `COVARIANCE_FACTOR_NAMES`). + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the covariance matrix file. + + Raises: + AssertionError: + If `factor_name` is not in `COVARIANCE_FACTOR_NAMES`. + """ assert factor_name in COVARIANCE_FACTOR_NAMES if partition is not None: data_partition, module_partition = partition @@ -52,7 +70,22 @@ def save_covariance_matrices( partition: Optional[PARTITION_TYPE] = None, metadata: Optional[Dict[str, str]] = None, ) -> None: - """Saves covariance matrices to disk.""" + """Saves covariance matrices to disk. + + Args: + output_dir (Path): + Directory to save the matrices. + factors (FACTOR_TYPE): + Dictionary of factors to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the factors. + + Raises: + AssertionError: + If factors keys don't match `COVARIANCE_FACTOR_NAMES`. + """ assert set(factors.keys()) == set(COVARIANCE_FACTOR_NAMES) for factor_name in factors: save_path = covariance_matrices_save_path( @@ -67,7 +100,18 @@ def load_covariance_matrices( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> FACTOR_TYPE: - """Loads covariance matrices from disk.""" + """Loads covariance matrices from disk. + + Args: + output_dir (Path): + Directory to load the matrices from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + FACTOR_TYPE: + Dictionary of loaded covariance factors. 
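# --- Illustrative aside (not part of the diff) ---------------------------------
# A minimal sketch of what the `kronfluence/factor/config.py` hunk above now
# does for KFAC/EKFAC: `prepare` folds the (heuristic) damping term into the
# Lambda matrix once, so `precondition_gradient` only has to rotate into the
# Kronecker eigenbasis, divide, and rotate back. All tensors below are random
# stand-ins, not library state.
import torch

batch_size, out_dim, in_dim = 4, 3, 5
per_sample_gradient = torch.randn(batch_size, out_dim, in_dim)
activation_eigenvectors = torch.linalg.qr(torch.randn(in_dim, in_dim)).Q
gradient_eigenvectors = torch.linalg.qr(torch.randn(out_dim, out_dim)).Q
lambda_matrix = torch.rand(out_dim, in_dim)

# Done once in `prepare`: add the damping term to Lambda.
damping_factor = 0.1 * torch.mean(lambda_matrix)
lambda_matrix = lambda_matrix + damping_factor

# Done per batch in `precondition_gradient`: rotate, divide, rotate back.
rotated = gradient_eigenvectors.t() @ per_sample_gradient @ activation_eigenvectors
preconditioned = gradient_eigenvectors @ (rotated / lambda_matrix) @ activation_eigenvectors.t()

# Pairwise influence scores later reduce to inner products between query
# preconditioned gradients and train per-sample gradients.
scores = torch.einsum("qoi,boi->qb", preconditioned, per_sample_gradient)
print(preconditioned.shape, scores.shape)  # torch.Size([4, 3, 5]) torch.Size([4, 4])
# --------------------------------------------------------------------------------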
+ """ covariance_factors = {} for factor_name in COVARIANCE_FACTOR_NAMES: save_path = covariance_matrices_save_path( @@ -83,7 +127,18 @@ def covariance_matrices_exist( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> bool: - """Checks if covariance matrices exist at the specified directory.""" + """Checks if covariance matrices exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + bool: + `True` if all covariance matrices exist, `False` otherwise. + """ for factor_name in COVARIANCE_FACTOR_NAMES: save_path = covariance_matrices_save_path( output_dir=output_dir, @@ -121,25 +176,22 @@ def fit_covariance_matrices_with_loader( A list of module names for which covariance matrices will be computed. If not specified, covariance matrices will be computed for all tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: Tuple[torch.Tensor, FACTOR_TYPE]: - A tuple containing the number of data points processed and computed covariance matrices. - The covariance matrices are organized in nested dictionaries, where the first key is the name of the - covariance matrix (e.g., activation covariance and pseudo-gradient covariance) and the second key is - the module name. + - Number of data points processed. + - Computed covariance matrices (nested dict: factor_name -> module_name -> tensor). """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - tracked_module_names=tracked_module_names, - mode=ModuleMode.COVARIANCE, - keep_factors=False, - ) + update_factor_args(model=model, factor_args=factor_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + tracked_module_names=tracked_module_names, + mode=ModuleMode.COVARIANCE, + release_memory=True, + ) total_steps = 0 num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False) @@ -157,10 +209,10 @@ def fit_covariance_matrices_with_loader( ) as pbar: for index, batch in enumerate(loader): batch = send_to_device(batch, device=state.device) - with torch.no_grad(): - attention_mask = task.get_attention_mask(batch=batch) - if attention_mask is not None: - set_attention_mask(model=model, attention_mask=attention_mask) + + attention_mask = task.get_attention_mask(batch=batch) + if attention_mask is not None: + set_attention_mask(model=model, attention_mask=attention_mask) with no_sync(model=model, state=state): model.zero_grad(set_to_none=True) @@ -174,34 +226,38 @@ def fit_covariance_matrices_with_loader( if ( state.use_distributed - and total_steps % factor_args.distributed_sync_interval == 0 + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 and index not in [len(loader) - 1, len(loader) - 2] ): - # Periodically synchronizes all processes to avoid timeout at the final synchronization. state.wait_for_everyone() num_data_processed.add_(find_batch_size(data=batch)) total_steps += 1 pbar.update(1) - with torch.no_grad(): - if state.use_distributed: - # Aggregates covariance matrices across multiple devices or nodes. 
- synchronize_covariance_matrices(model=model, tracked_module_names=tracked_module_names) - num_data_processed = num_data_processed.to(device=state.device) - dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) - - saved_factors: FACTOR_TYPE = {} - for factor_name in COVARIANCE_FACTOR_NAMES: - saved_factors[factor_name] = load_factors( - model=model, factor_name=factor_name, tracked_module_names=tracked_module_names, clone=False - ) - - # Clean up the memory. - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) - state.wait_for_everyone() + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) + num_data_processed = num_data_processed.to(device=state.device) + dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + num_data_processed = num_data_processed.cpu() + + saved_factors: FACTOR_TYPE = {} + for factor_name in COVARIANCE_FACTOR_NAMES: + factor = load_factors( + model=model, + factor_name=factor_name, + tracked_module_names=tracked_module_names, + clone=True, + ) + if factor is None: + raise ValueError(f"Factor `{factor_name}` has not been computed.") + saved_factors[factor_name] = factor + + model.zero_grad(set_to_none=True) + set_attention_mask(model=model, attention_mask=None) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() return num_data_processed, saved_factors diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index 8f7980c..4f5c8da 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -14,13 +14,13 @@ from kronfluence.arguments import FactorArguments from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( - finalize_lambda_matrices, + finalize_iteration, get_tracked_module_names, load_factors, set_factors, set_gradient_scale, set_mode, - synchronize_lambda_matrices, + synchronize_modules, update_factor_args, ) from kronfluence.task import Task @@ -28,6 +28,7 @@ ACTIVATION_COVARIANCE_MATRIX_NAME, ACTIVATION_EIGENVALUES_NAME, ACTIVATION_EIGENVECTORS_NAME, + DISTRIBUTED_SYNC_INTERVAL, EIGENDECOMPOSITION_FACTOR_NAMES, FACTOR_TYPE, GRADIENT_COVARIANCE_MATRIX_NAME, @@ -46,13 +47,41 @@ def eigendecomposition_save_path( output_dir: Path, factor_name: str, ) -> Path: - """Generates the path for saving/loading Eigendecomposition results.""" + """Generates the path for saving or loading eigendecomposition results. + + Args: + output_dir (Path): + Directory to save eigenvectors and eigenvalues. + factor_name (str): + Name of the factor (must be in `EIGENDECOMPOSITION_FACTOR_NAMES`). + + Returns: + Path: + The full path for the eigendecomposition file. + + Raises: + AssertionError: + If `factor_name` is not in `EIGENDECOMPOSITION_FACTOR_NAMES`. + """ assert factor_name in EIGENDECOMPOSITION_FACTOR_NAMES return output_dir / f"{factor_name}.safetensors" def save_eigendecomposition(output_dir: Path, factors: FACTOR_TYPE, metadata: Optional[Dict[str, str]] = None) -> None: - """Saves Eigendecomposition results to disk.""" + """Saves eigendecomposition results to disk. + + Args: + output_dir (Path): + Directory to save the eigenvectors and eigenvalues. + factors (FACTOR_TYPE): + Dictionary of factors to save. 
+ metadata (Dict[str, str], optional): + Additional metadata to save with the factors. + + Raises: + AssertionError: + If factors keys don't match `EIGENDECOMPOSITION_FACTOR_NAMES`. + """ assert set(factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) for factor_name in factors: save_path = eigendecomposition_save_path( @@ -65,7 +94,16 @@ def save_eigendecomposition(output_dir: Path, factors: FACTOR_TYPE, metadata: Op def load_eigendecomposition( output_dir: Path, ) -> FACTOR_TYPE: - """Loads Eigendecomposition results from disk.""" + """Loads eigendecomposition results from disk. + + Args: + output_dir (Path): + Directory to load the results from. + + Returns: + FACTOR_TYPE: + Dictionary of loaded eigendecomposition results. + """ eigen_factors = {} for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: save_path = eigendecomposition_save_path( @@ -79,7 +117,16 @@ def load_eigendecomposition( def eigendecomposition_exist( output_dir: Path, ) -> bool: - """Checks if Eigendecomposition results exist at the specified path.""" + """Checks if eigendecomposition results exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for results. + + Returns: + bool: + `True` if all eigendecomposition results exist, `False` otherwise. + """ for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: save_path = eigendecomposition_save_path( output_dir=output_dir, @@ -98,24 +145,24 @@ def perform_eigendecomposition( factor_args: FactorArguments, disable_tqdm: bool = False, ) -> FACTOR_TYPE: - """Performs Eigendecomposition on activation and pseudo-gradient covariance matrices. + """Performs eigendecomposition on activation and pseudo-gradient covariance matrices. Args: covariance_factors (FACTOR_TYPE): - The model used to compute covariance matrices. + Computed covariance factors. model (nn.Module): - The model which contains modules which Eigendecomposition will be performed. + The model used to compute covariance matrices. state (State): The current process's information (e.g., device being used). factor_args (FactorArguments): - Arguments for computing Eigendecomposition. + Arguments for performing eigendecomposition. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: FACTOR_TYPE: - The Eigendecomposition results. These results are organized in nested dictionaries, where the first key - is the name of the factor (e.g., activation eigenvector), and the second key is the module name. + The results are organized in nested dictionaries, where the first key is the name of the factor + (e.g., activation eigenvector), and the second key is the module name. """ eigen_factors: FACTOR_TYPE = {} for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: @@ -148,7 +195,7 @@ def perform_eigendecomposition( device=state.device, dtype=factor_args.eigendecomposition_dtype, ) - # Normalizes covariance matrices. + # Normalize covariance matrices. covariance_matrix.div_(covariance_factors[num_processed_name][module_name].to(device=state.device)) # In cases where covariance matrices are not exactly symmetric due to numerical issues. 
covariance_matrix = covariance_matrix + covariance_matrix.t() @@ -163,9 +210,14 @@ def perform_eigendecomposition( eigenvalues, eigenvectors = torch.linalg.eigh(covariance_matrix) else: raise - eigen_factors[eigenvalues_name][module_name] = eigenvalues.to(dtype=original_dtype).contiguous().cpu() - eigen_factors[eigenvectors_name][module_name] = eigenvectors.to(dtype=original_dtype).contiguous().cpu() - del covariance_matrix, eigenvalues, eigenvectors + del covariance_matrix + eigen_factors[eigenvalues_name][module_name] = eigenvalues.contiguous().to( + dtype=original_dtype, device="cpu" + ) + eigen_factors[eigenvectors_name][module_name] = eigenvectors.contiguous().to( + dtype=original_dtype, device="cpu" + ) + del eigenvalues, eigenvectors pbar.update(1) @@ -177,7 +229,24 @@ def lambda_matrices_save_path( factor_name: str, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading Lambda matrices.""" + """Generates the path for saving or loading Lambda matrices. + + Args: + output_dir (Path): + Directory to save the matrices. + factor_name (str): + Name of the factor (must be in `LAMBDA_FACTOR_NAMES`). + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the Lambda matrix file. + + Raises: + AssertionError: + If `factor_name` is not in `LAMBDA_FACTOR_NAMES`. + """ assert factor_name in LAMBDA_FACTOR_NAMES if partition is not None: data_partition, module_partition = partition @@ -193,7 +262,22 @@ def save_lambda_matrices( partition: Optional[PARTITION_TYPE] = None, metadata: Optional[Dict[str, str]] = None, ) -> None: - """Saves Lambda matrices to disk.""" + """Saves Lambda matrices to disk. + + Args: + output_dir (Path): + Directory to save the matrices. + factors (FACTOR_TYPE): + Dictionary of factors to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the factors. + + Raises: + AssertionError: + If factors keys don't match `LAMBDA_FACTOR_NAMES`. + """ assert set(factors.keys()) == set(LAMBDA_FACTOR_NAMES) for factor_name in factors: save_path = lambda_matrices_save_path( @@ -208,7 +292,18 @@ def load_lambda_matrices( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> FACTOR_TYPE: - """Loads Lambda matrices from disk.""" + """Loads Lambda matrices from disk. + + Args: + output_dir (Path): + Directory to load the matrices from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + FACTOR_TYPE: + Dictionary of loaded Lambda factors. + """ lambda_factors = {} for factor_name in LAMBDA_FACTOR_NAMES: save_path = lambda_matrices_save_path( @@ -224,7 +319,18 @@ def lambda_matrices_exist( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> bool: - """Checks if Lambda matrices exist at the specified path.""" + """Checks if Lambda matrices exist at the specified directory. + + Args: + output_dir (Path): + Directory to check for matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + bool: + `True` if all Lambda matrices exist, `False` otherwise. 
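# --- Illustrative aside (not part of the diff) ---------------------------------
# A condensed sketch of the two steps surrounding this hunk: the decomposition
# performed in `perform_eigendecomposition` above and the Lambda (corrected
# eigenvalue) accumulation driven by `fit_lambda_matrices_with_loader` below.
# Shapes and counts are dummy values.
import torch

in_dim, out_dim, batch_size = 5, 3, 4

# Eigendecomposition: normalize the accumulated covariance by the number of
# processed rows, symmetrize against numerical asymmetry, then call `eigh`.
activation_covariance = torch.randn(in_dim, in_dim, dtype=torch.float64)
activation_covariance = activation_covariance @ activation_covariance.t()
activation_covariance = activation_covariance / torch.tensor([128.0], dtype=torch.float64)
activation_covariance = activation_covariance + activation_covariance.t()
activation_eigenvalues, activation_eigenvectors = torch.linalg.eigh(activation_covariance)

# Lambda accumulation (EKFAC): rotate each per-sample gradient into the
# eigenbasis and add up its squared entries; the count is divided out later.
gradient_eigenvectors = torch.linalg.qr(torch.randn(out_dim, out_dim, dtype=torch.float64)).Q
per_sample_gradient = torch.randn(batch_size, out_dim, in_dim, dtype=torch.float64)
rotated = gradient_eigenvectors.t() @ per_sample_gradient @ activation_eigenvectors
lambda_matrix = rotated.square().sum(dim=0)   # shape (out_dim, in_dim)
num_lambda_processed = batch_size
# --------------------------------------------------------------------------------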
+ """ for factor_name in LAMBDA_FACTOR_NAMES: save_path = lambda_matrices_save_path( output_dir=output_dir, @@ -246,7 +352,7 @@ def fit_lambda_matrices_with_loader( tracked_module_names: Optional[List[str]] = None, disable_tqdm: bool = False, ) -> Tuple[torch.Tensor, FACTOR_TYPE]: - """Computes Lambda (corrected eigenvalues) matrices for a given model and task. + """Computes Lambda matrices for a given model and task. Args: model (nn.Module): @@ -260,32 +366,31 @@ def fit_lambda_matrices_with_loader( factor_args (FactorArguments): Arguments for computing Lambda matrices. eigen_factors (FACTOR_TYPE, optional): - The eigendecomposition results to use for computing Lambda matrices. + Computed eigendecomposition results. tracked_module_names (List[str], optional): A list of module names for which Lambda matrices will be computed. If not specified, Lambda matrices will be computed for all tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: Tuple[torch.Tensor, FACTOR_TYPE]: - A tuple containing the number of data points processed and computed Lambda matrices. - The Lambda matrices are organized in nested dictionaries, where the first key is the name of - the computed variable and the second key is the module name. + - Number of data points processed. + - Computed Lambda matrices (nested dict: factor_name -> module_name -> tensor). """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - tracked_module_names=tracked_module_names, - mode=ModuleMode.LAMBDA, - keep_factors=False, - ) - if eigen_factors is not None: - for name in eigen_factors: - set_factors(model=model, factor_name=name, factors=eigen_factors[name]) + update_factor_args(model=model, factor_args=factor_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + tracked_module_names=tracked_module_names, + mode=ModuleMode.LAMBDA, + release_memory=True, + ) + if eigen_factors is not None: + for name in eigen_factors: + set_factors(model=model, factor_name=name, factors=eigen_factors[name], clone=True) + del eigen_factors total_steps = 0 num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False) @@ -315,41 +420,41 @@ def fit_lambda_matrices_with_loader( scaler.scale(loss).backward() if factor_args.has_shared_parameters: - # If shared parameter exists, Lambda matrices are computed and updated only after all - # per-sample-gradients are aggregated. - finalize_lambda_matrices(model=model, tracked_module_names=tracked_module_names) + finalize_iteration(model=model, tracked_module_names=tracked_module_names) if ( state.use_distributed - and total_steps % factor_args.distributed_sync_interval == 0 + and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0 and index not in [len(loader) - 1, len(loader) - 2] ): - # Periodically synchronizes all processes to avoid timeout at the final synchronization. state.wait_for_everyone() num_data_processed.add_(find_batch_size(data=batch)) total_steps += 1 pbar.update(1) - with torch.no_grad(): - if state.use_distributed: - # Aggregates Lambda matrices across multiple devices or nodes. 
- synchronize_lambda_matrices(model=model, tracked_module_names=tracked_module_names) - num_data_processed = num_data_processed.to(device=state.device) - dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) - - saved_factors: FACTOR_TYPE = {} - if state.is_main_process: - for factor_name in LAMBDA_FACTOR_NAMES: - saved_factors[factor_name] = load_factors( - model=model, factor_name=factor_name, tracked_module_names=tracked_module_names, clone=False - ) + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) + num_data_processed = num_data_processed.to(device=state.device) + dist.all_reduce(tensor=num_data_processed, op=torch.distributed.ReduceOp.SUM) + num_data_processed = num_data_processed.cpu() - # Clean up the memory. - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) - state.wait_for_everyone() + saved_factors: FACTOR_TYPE = {} + for factor_name in LAMBDA_FACTOR_NAMES: + factor = load_factors( + model=model, + factor_name=factor_name, + tracked_module_names=tracked_module_names, + clone=True, + ) + if factor is None: + raise ValueError(f"Factor `{factor_name}` has not been computed.") + saved_factors[factor_name] = factor + + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() return num_data_processed, saved_factors diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index 8507ccb..9dfc234 100644 --- a/kronfluence/module/conv2d.py +++ b/kronfluence/module/conv2d.py @@ -4,17 +4,11 @@ import torch.nn.functional as F from einconv.utils import get_conv_paddings from einops import rearrange, reduce -from opt_einsum import DynamicProgramming, contract, contract_expression +from opt_einsum import DynamicProgramming, contract_expression from torch import nn from torch.nn.modules.utils import _pair -from kronfluence.factor.config import FactorConfig from kronfluence.module.tracked_module import TrackedModule -from kronfluence.utils.constants import ( - ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, - PAIRWISE_SCORE_MATRIX_NAME, - SELF_SCORE_VECTOR_NAME, -) from kronfluence.utils.exceptions import UnsupportableModuleError @@ -71,7 +65,7 @@ def extract_patches( class TrackedConv2d(TrackedModule, module_type=nn.Conv2d): - """A tracking wrapper for `nn.Conv2D` modules.""" + """A wrapper for `nn.Conv2d` modules.""" @property def in_channels(self) -> int: # pylint: disable=missing-function-docstring @@ -109,9 +103,7 @@ def weight(self) -> torch.Tensor: # pylint: disable=missing-function-docstring def bias(self) -> Optional[torch.Tensor]: # pylint: disable=missing-function-docstring return self.original_module.bias - def _get_flattened_activation( - self, input_activation: torch.Tensor - ) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: input_activation = extract_patches( inputs=input_activation, kernel_size=self.original_module.kernel_size, @@ -136,16 +128,11 @@ def _get_flattened_activation( count = input_activation.size(0) return input_activation, count - def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_gradient(self, 
output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: output_gradient = rearrange(output_gradient, "b c o1 o2 -> (b o1 o2) c") return output_gradient, output_gradient.size(0) - @torch.no_grad() - def _compute_per_sample_gradient( - self, - input_activation: torch.Tensor, - output_gradient: torch.Tensor, - ) -> torch.Tensor: + def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Tensor: input_activation = extract_patches( inputs=input_activation, kernel_size=self.original_module.kernel_size, @@ -167,42 +154,43 @@ def _compute_per_sample_gradient( ], dim=-1, ) + return input_activation + + def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - return torch.einsum("abm,abn->amn", output_gradient, input_activation) - - @torch.no_grad() - def _compute_pairwise_score(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> None: - input_activation = extract_patches( - inputs=input_activation, - kernel_size=self.original_module.kernel_size, - stride=self.original_module.stride, - padding=self.original_module.padding, - dilation=self.original_module.dilation, - groups=self.original_module.groups, - ) - input_activation = rearrange( - tensor=input_activation, - pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", - ) + summed_gradient = torch.einsum("bci,bco->io", output_gradient, input_activation) + return summed_gradient.view((1, *summed_gradient.size())) - if self.original_module.bias is not None: - input_activation = torch.cat( - [ - input_activation, - input_activation.new_ones((input_activation.size(0), 1), requires_grad=False), - ], - dim=-1, - ) + def compute_per_sample_gradient( + self, + input_activation: torch.Tensor, + output_gradient: torch.Tensor, + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") + per_sample_gradient = torch.einsum("bci,bco->bio", output_gradient, input_activation) + if self.per_sample_gradient_process_fnc is not None: + per_sample_gradient = self.per_sample_gradient_process_fnc( + module_name=self.name, gradient=per_sample_gradient + ) + return per_sample_gradient - if isinstance(self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], list): - left_mat, right_mat = self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] + @torch.no_grad() + def compute_pairwise_score( + self, preconditioned_gradient, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) + output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") - if self._opt_einsum_expression is None: - self._opt_einsum_expression = contract_expression( - "qik,qko,b...i,b...o->qb", + if isinstance(preconditioned_gradient, list): + left_mat, right_mat = preconditioned_gradient + if self.einsum_expression is None: + self.einsum_expression = 
contract_expression( + "qik,qko,bci,bco->qb", left_mat.shape, right_mat.shape, output_gradient.shape, @@ -211,21 +199,34 @@ def _compute_pairwise_score(self, input_activation: torch.Tensor, output_gradien search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" ), ) - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = self._opt_einsum_expression( - left_mat, right_mat, output_gradient, input_activation - ) - + return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation) else: - if self._opt_einsum_expression is None: - self._opt_einsum_expression = contract_expression( + if self.einsum_expression is None: + self.einsum_expression = contract_expression( "qio,bti,bto->qb", - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME].shape, + preconditioned_gradient.shape, output_gradient.shape, input_activation.shape, optimize=DynamicProgramming( search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" ), ) - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = self._opt_einsum_expression( - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], output_gradient, input_activation + return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) + + def compute_self_measurement_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) + output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") + if self.einsum_expression is None: + self.einsum_expression = contract_expression( + "bio,bci,bco->b", + preconditioned_gradient.shape, + output_gradient.shape, + input_activation.shape, + optimize=DynamicProgramming( + search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" + ), ) + return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index 7e40ae1..ae7b063 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -2,20 +2,14 @@ import torch from einops import rearrange -from opt_einsum import DynamicProgramming, contract, contract_expression +from opt_einsum import DynamicProgramming, contract_expression from torch import nn -from kronfluence.factor.config import FactorConfig from kronfluence.module.tracked_module import TrackedModule -from kronfluence.utils.constants import ( - ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, - PAIRWISE_SCORE_MATRIX_NAME, - SELF_SCORE_VECTOR_NAME, -) class TrackedLinear(TrackedModule, module_type=nn.Linear): - """A tracking wrapper for `nn.Linear` modules.""" + """A wrapper for `nn.Linear` modules.""" @property def in_features(self) -> int: # pylint: disable=missing-function-docstring @@ -33,15 +27,13 @@ def weight(self) -> torch.Tensor: # pylint: disable=missing-function-docstring def bias(self) -> Optional[torch.Tensor]: # pylint: disable=missing-function-docstring return self.original_module.bias - def _get_flattened_activation( - self, input_activation: torch.Tensor - ) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: flattened_activation = rearrange(tensor=input_activation, pattern="b ... d_in -> (b ...) 
d_in") flattened_attention_mask = None - if self._attention_mask is not None and flattened_activation.size(0) == self._attention_mask.numel(): + if self.attention_mask is not None and flattened_activation.size(0) == self.attention_mask.numel(): # If the binary attention mask is provided, zero-out appropriate activations. - flattened_attention_mask = rearrange(tensor=self._attention_mask, pattern="b ... -> (b ...) 1") + flattened_attention_mask = rearrange(tensor=self.attention_mask, pattern="b ... -> (b ...) 1") flattened_activation.mul_(flattened_attention_mask) if self.original_module.bias is not None: @@ -53,37 +45,47 @@ def _get_flattened_activation( count = flattened_activation.size(0) if flattened_attention_mask is None else flattened_attention_mask.sum() return flattened_activation, count - def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: flattened_gradient = rearrange(tensor=output_gradient, pattern="b ... d_out -> (b ...) d_out") - if self._attention_mask is not None and flattened_gradient.size(0) == self._attention_mask.numel(): - count = self._attention_mask.sum() + if self.attention_mask is not None and flattened_gradient.size(0) == self.attention_mask.numel(): + count = self.attention_mask.sum() else: count = flattened_gradient.size(0) return flattened_gradient, count - @torch.no_grad() - def _compute_per_sample_gradient( - self, input_activation: torch.Tensor, output_gradient: torch.Tensor - ) -> torch.Tensor: + def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Tensor: if self.original_module.bias is not None: shape = list(input_activation.size()[:-1]) + [1] append_term = input_activation.new_ones(shape, requires_grad=False) input_activation = torch.cat([input_activation, append_term], dim=-1) - return torch.einsum("b...i,b...o->bio", output_gradient, input_activation) + return input_activation - @torch.no_grad() - def _compute_pairwise_score(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> None: - if self.original_module.bias is not None: - shape = list(input_activation.size()[:-1]) + [1] - append_term = input_activation.new_ones(shape, requires_grad=False) - input_activation = torch.cat([input_activation, append_term], dim=-1) + def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + summed_gradient = torch.einsum("b...i,b...o->io", output_gradient, input_activation) + return summed_gradient.view((1, *summed_gradient.size())) - if isinstance(self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], list): - left_mat, right_mat = self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] + def compute_per_sample_gradient( + self, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + per_sample_gradient = torch.einsum("b...i,b...o->bio", output_gradient, input_activation) + if self.per_sample_gradient_process_fnc is not None: + per_sample_gradient = self.per_sample_gradient_process_fnc( + module_name=self.name, gradient=per_sample_gradient + ) + return per_sample_gradient + + def compute_pairwise_score( + self, preconditioned_gradient, input_activation: torch.Tensor, output_gradient: 
torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + if isinstance(preconditioned_gradient, list): + left_mat, right_mat = preconditioned_gradient - if self._opt_einsum_expression is None: + if self.einsum_expression is None: if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: - self._opt_einsum_expression = contract_expression( + self.einsum_expression = contract_expression( "qik,qko,bti,bto->qbt", left_mat.shape, right_mat.shape, @@ -94,7 +96,7 @@ def _compute_pairwise_score(self, input_activation: torch.Tensor, output_gradien ), ) else: - self._opt_einsum_expression = contract_expression( + self.einsum_expression = contract_expression( "qik,qko,b...i,b...o->qb", left_mat.shape, right_mat.shape, @@ -104,16 +106,13 @@ def _compute_pairwise_score(self, input_activation: torch.Tensor, output_gradien search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" ), ) - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = self._opt_einsum_expression( - left_mat, right_mat, output_gradient, input_activation - ) - + return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation) else: - if self._opt_einsum_expression is None: + if self.einsum_expression is None: if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3: - self._opt_einsum_expression = contract_expression( + self.einsum_expression = contract_expression( "qio,bti,bto->qbt", - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME].shape, + preconditioned_gradient.shape, output_gradient.shape, input_activation.shape, optimize=DynamicProgramming( @@ -121,15 +120,29 @@ def _compute_pairwise_score(self, input_activation: torch.Tensor, output_gradien ), ) else: - self._opt_einsum_expression = contract_expression( + self.einsum_expression = contract_expression( "qio,b...i,b...o->qb", - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME].shape, + preconditioned_gradient.shape, output_gradient.shape, input_activation.shape, optimize=DynamicProgramming( search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" ), ) - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = self._opt_einsum_expression( - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], output_gradient, input_activation + return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) + + def compute_self_measurement_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + input_activation = self._flatten_input_activation(input_activation=input_activation) + if self.einsum_expression is None: + self.einsum_expression = contract_expression( + "bio,b...i,b...o->b", + preconditioned_gradient.shape, + output_gradient.shape, + input_activation.shape, + optimize=DynamicProgramming( + search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" + ), ) + return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation) diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index bde63a0..f063498 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -1,39 +1,37 @@ from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union import torch -import 
torch.distributed as dist from accelerate.utils.dataclasses import BaseEnum -from opt_einsum import DynamicProgramming, contract, contract_expression from torch import nn -from torch.utils.hooks import RemovableHandle from kronfluence.arguments import FactorArguments, ScoreArguments from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.module.tracker.factor import CovarianceTracker, LambdaTracker +from kronfluence.module.tracker.gradient import GradientTracker +from kronfluence.module.tracker.pairwise_score import PairwiseScoreTracker +from kronfluence.module.tracker.precondition import PreconditionTracker +from kronfluence.module.tracker.self_score import ( + SelfScoreTracker, + SelfScoreWithMeasurementTracker, +) from kronfluence.utils.constants import ( ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, - ACTIVATION_COVARIANCE_MATRIX_NAME, - ACTIVATION_EIGENVECTORS_NAME, + AGGREGATED_GRADIENT_NAME, COVARIANCE_FACTOR_NAMES, EIGENDECOMPOSITION_FACTOR_NAMES, - GRADIENT_COVARIANCE_MATRIX_NAME, - GRADIENT_EIGENVECTORS_NAME, LAMBDA_FACTOR_NAMES, - LAMBDA_MATRIX_NAME, - NUM_ACTIVATION_COVARIANCE_PROCESSED, - NUM_GRADIENT_COVARIANCE_PROCESSED, - NUM_LAMBDA_PROCESSED, PAIRWISE_SCORE_MATRIX_NAME, PRECONDITIONED_GRADIENT_NAME, PRECONDITIONED_GRADIENT_TYPE, SELF_SCORE_VECTOR_NAME, ) -from kronfluence.utils.exceptions import FactorsNotFoundError class ModuleMode(str, BaseEnum): - """Enum to represent a module's mode, indicating which factors and scores need to be computed - during forward and backward passes.""" + """Enum representing a module's mode, indicating which factors and scores + need to be computed during forward and backward passes.""" DEFAULT = "default" COVARIANCE = "covariance" @@ -46,7 +44,11 @@ class ModuleMode(str, BaseEnum): class TrackedModule(nn.Module): - """A wrapper class for PyTorch modules to compute influence factors and scores.""" + """A wrapper class for PyTorch modules to compute influence factors and scores. + + This class extends `nn.Module` to add functionality for tracking and computing + various influence-related metrics. + """ SUPPORTED_MODULES: Dict[Type[nn.Module], Any] = {} @@ -64,7 +66,7 @@ def __init__( score_args: Optional[ScoreArguments] = None, per_sample_gradient_process_fnc: Optional[Callable] = None, ) -> None: - """Initializes an instance of the TrackedModule class. + """Initializes an instance of the `TrackedModule` class. Args: name (str): @@ -76,7 +78,7 @@ def __init__( score_args (ScoreArguments, optional): Arguments for computing influence scores. per_sample_gradient_process_fnc (Callable, optional): - An optional function to post process per-sample-gradient. + Function to post-process per-sample gradients. 
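# --- Illustrative aside (not part of the diff) ---------------------------------
# A hypothetical `per_sample_gradient_process_fnc`. It is called with keyword
# arguments `module_name` and `gradient` (the per-sample gradient of shape
# `batch_size x gradient_dim x activation_dim`) and should return a tensor of
# the same shape; this sketch clips each sample to unit Frobenius norm.
import torch

def clip_per_sample_gradient(module_name: str, gradient: torch.Tensor) -> torch.Tensor:
    # Divide by the per-sample norm only when it exceeds 1.0.
    norms = gradient.flatten(start_dim=1).norm(dim=1).clamp(min=1.0)
    return gradient / norms.view(-1, 1, 1)
# --------------------------------------------------------------------------------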
""" super().__init__() @@ -90,37 +92,57 @@ def __init__( dtype=torch.float16, ) ) + self.current_mode = ModuleMode.DEFAULT self.factor_args = FactorArguments() if factor_args is None else factor_args self.score_args = ScoreArguments() if score_args is None else score_args self.per_sample_gradient_process_fnc = per_sample_gradient_process_fnc - self.aggregated_gradient: Optional[torch.Tensor] = None - - self._cached_activations: Optional[Union[List[torch.Tensor]], torch.Tensor] = None - self._cached_per_sample_gradient: Optional[torch.Tensor] = None - self._opt_einsum_expression: Optional[Callable] = None - self._attention_mask: Optional[torch.Tensor] = None - self._gradient_scale: float = 1.0 - self._registered_hooks: List[RemovableHandle] = [] - self._storage: Dict[str, Optional[Union[torch.Tensor, PRECONDITIONED_GRADIENT_TYPE]]] = {} - self._storage_at_device: bool = False - - # Storage for activation and pseudo-gradient covariance matrices. # + self.einsum_expression = None + + self._trackers = { + ModuleMode.DEFAULT: BaseTracker(self), + ModuleMode.COVARIANCE: CovarianceTracker(self), + ModuleMode.LAMBDA: LambdaTracker(self), + ModuleMode.GRADIENT_AGGREGATION: GradientTracker(self), + ModuleMode.PRECONDITION_GRADIENT: PreconditionTracker(self), + ModuleMode.PAIRWISE_SCORE: PairwiseScoreTracker(self), + ModuleMode.SELF_SCORE: SelfScoreTracker(self), + ModuleMode.SELF_MEASUREMENT_SCORE: SelfScoreWithMeasurementTracker(self), + } + + self.attention_mask: Optional[torch.Tensor] = None + self.gradient_scale: float = 1.0 + self.storage: Dict[str, Optional[Union[torch.Tensor, PRECONDITIONED_GRADIENT_TYPE]]] = {} + + # Storage for activation and pseudo-gradient covariance matrices # for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - self._storage[covariance_factor_name]: Optional[torch.Tensor] = None + self.storage[covariance_factor_name]: Optional[torch.Tensor] = None - # Storage for eigenvectors and eigenvalues. # + # Storage for eigenvectors and eigenvalues # for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: - self._storage[eigen_factor_name]: Optional[torch.Tensor] = None + self.storage[eigen_factor_name]: Optional[torch.Tensor] = None - # Storage for lambda matrices. # + # Storage for lambda matrices # for lambda_factor_name in LAMBDA_FACTOR_NAMES: - self._storage[lambda_factor_name]: Optional[torch.Tensor] = None + self.storage[lambda_factor_name]: Optional[torch.Tensor] = None + + # Storage for preconditioned query gradients and influence scores # + self.storage[AGGREGATED_GRADIENT_NAME]: Optional[torch.Tensor] = None + self.storage[PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None + self.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None + self.storage[PAIRWISE_SCORE_MATRIX_NAME]: Optional[torch.Tensor] = None + self.storage[SELF_SCORE_VECTOR_NAME]: Optional[torch.Tensor] = None + + def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any: + """A forward pass of the tracked module. This should have identical behavior to that of the original module.""" + return self.original_module(inputs + self._constant, *args, **kwargs) - # Storage for preconditioned query gradients and influence scores. 
# - self._storage[PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME]: PRECONDITIONED_GRADIENT_TYPE = None - self._storage[PAIRWISE_SCORE_MATRIX_NAME]: Optional[torch.Tensor] = None - self._storage[SELF_SCORE_VECTOR_NAME]: Optional[torch.Tensor] = None + def prepare_storage(self, device: torch.device) -> None: + """Performs any necessary operations on storage before computing any metrics.""" + FactorConfig.CONFIGS[self.factor_args.strategy].prepare( + storage=self.storage, + score_args=self.score_args, + device=device, + ) def update_factor_args(self, factor_args: FactorArguments) -> None: """Updates the factor arguments.""" @@ -130,76 +152,80 @@ def update_score_args(self, score_args: ScoreArguments) -> None: """Updates the score arguments.""" self.score_args = score_args - def get_factor(self, factor_name: str, clone: bool = False) -> Optional[torch.Tensor]: + def get_factor(self, factor_name: str) -> Optional[torch.Tensor]: """Returns the factor with the given name.""" - if factor_name not in self._storage: + if factor_name not in self.storage or self.storage[factor_name] is None: + return None + return self.storage[factor_name] + + def release_factor(self, factor_name: str) -> None: + """Release the factor with the given name from memory.""" + if factor_name not in self.storage or self.storage[factor_name] is None: return None - return self._storage[factor_name].clone() if clone else self._storage[factor_name] + del self.storage[factor_name] + self.storage[factor_name] = None def set_factor(self, factor_name: str, factor: Any) -> None: """Sets the factor with the given name.""" - if factor_name in self._storage: - self._storage[factor_name] = factor + if factor_name in self.storage: + self.storage[factor_name] = factor - def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any: - """A forward pass of the tracked module. 
This should have identical behavior to that of the original module.""" - return self.original_module(inputs + self._constant, *args, **kwargs) - - def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None: + def set_mode(self, mode: ModuleMode, release_memory: bool = False) -> None: """Sets the module mode of the current `TrackedModule` instance.""" - self.set_attention_mask(attention_mask=None) - self._remove_registered_hooks() - self._opt_einsum_expression = None - - if not keep_factors: - self._release_covariance_matrices() - self._release_eigendecomposition_results() - self._release_lambda_matrix() - self.release_preconditioned_gradient() - self._storage_at_device = False - self.release_scores() + self._trackers[self.current_mode].release_hooks() + self.einsum_expression = None + self.current_mode = mode - if mode == ModuleMode.DEFAULT: - pass - elif mode == ModuleMode.COVARIANCE: - self._register_covariance_hooks() - elif mode == ModuleMode.LAMBDA: - self._register_lambda_hooks() - elif mode == ModuleMode.PRECONDITION_GRADIENT: - self._register_precondition_gradient_hooks() - elif mode == ModuleMode.PAIRWISE_SCORE: - self._register_pairwise_score_hooks() - elif mode == ModuleMode.SELF_SCORE: - self._register_self_score_hooks() - elif mode == ModuleMode.SELF_MEASUREMENT_SCORE: - self._register_self_measurement_score_hooks() - elif mode == ModuleMode.GRADIENT_AGGREGATION: - self._register_gradient_aggregation_hooks() - else: - raise RuntimeError(f"Unknown module mode {mode}.") + if release_memory: + for _mode in self._trackers: + self._trackers[_mode].release_memory() - def _remove_registered_hooks(self) -> None: - """Removes all registered hooks within the module.""" - while self._registered_hooks: - handle = self._registered_hooks.pop() - handle.remove() - self._registered_hooks = [] + self._trackers[self.current_mode].register_hooks() def set_attention_mask(self, attention_mask: Optional[torch.Tensor] = None) -> None: """Sets the attention mask for activation covariance computations.""" - self._attention_mask = attention_mask + self.attention_mask = attention_mask def set_gradient_scale(self, scale: float = 1.0) -> None: """Sets the scale of the gradient obtained from `GradScaler`.""" - self._gradient_scale = scale + self.gradient_scale = scale + + def finalize_iteration(self) -> None: + """Finalizes statistics for the current iteration.""" + self._trackers[self.current_mode].finalize_iteration() + + def exist(self) -> bool: + """Checks if the desired statistics are available.""" + return self._trackers[self.current_mode].exist() + + def synchronize(self, num_processes: int) -> None: + """Synchronizes statistics across multiple processes. + + Args: + num_processes (int): + The number of processes to synchronize across. + """ + self._trackers[self.current_mode].synchronize(num_processes=num_processes) + + def truncate(self, keep_size: int) -> None: + """Truncates stored statistics to a specified size. + + Args: + keep_size (int): + The number of dimension to keep. 
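# --- Illustrative aside (not part of the diff) ---------------------------------
# The tracker-dispatch pattern that the new `set_mode` above follows, reduced
# to a toy example with hypothetical classes. Switching modes releases the
# previous tracker's hooks, optionally releases stored memory, and registers
# the hooks of the newly selected tracker.
class ToyTracker:
    def register_hooks(self) -> None: ...
    def release_hooks(self) -> None: ...
    def release_memory(self) -> None: ...

class ToyModule:
    def __init__(self) -> None:
        self.current_mode = "default"
        self._trackers = {"default": ToyTracker(), "covariance": ToyTracker()}

    def set_mode(self, mode: str, release_memory: bool = False) -> None:
        self._trackers[self.current_mode].release_hooks()
        self.current_mode = mode
        if release_memory:
            for tracker in self._trackers.values():
                tracker.release_memory()
        self._trackers[self.current_mode].register_hooks()
# --------------------------------------------------------------------------------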
+ """ + self._trackers[self.current_mode].truncate(keep_size=keep_size) + + def accumulate_iterations(self) -> None: + """Accumulates (or prepares to accumulate) statistics across multiple iterations.""" + self._trackers[self.current_mode].accumulate_iterations() + + def finalize_all_iterations(self) -> None: + """Finalizes statistics after all iterations.""" + self._trackers[self.current_mode].finalize_all_iterations() - ############################################## - # Methods for computing covariance matrices. # - ############################################## @abstractmethod - def _get_flattened_activation( - self, input_activation: torch.Tensor - ) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: """Returns the flattened activation tensor and the number of stacked activations. Args: @@ -211,43 +237,10 @@ def _get_flattened_activation( The flattened activation tensor and the number of stacked activations. The flattened activation is a 2-dimensional matrix with dimension `activation_num x activation_dim`. """ - raise NotImplementedError("Subclasses must implement the `_get_flattened_activation` method.") - - def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -> None: - """Computes and updates the activation covariance matrix. - - Args: - input_activation (torch.Tensor): - The input tensor to the module, provided by the PyTorch's forward hook. - """ - input_activation = input_activation.to(dtype=self.factor_args.activation_covariance_dtype) - flattened_activation, count = self._get_flattened_activation(input_activation=input_activation) - - if self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] is None: - dimension = flattened_activation.size(1) - self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] = torch.zeros( - size=(dimension, dimension), - dtype=flattened_activation.dtype, - device=flattened_activation.device, - requires_grad=False, - ) - self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation) - - if self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] is None: - device = None - if isinstance(count, torch.Tensor): - # When using attention masks, `count` can be a tensor. - device = count.device - self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] = torch.zeros( - size=(1,), - dtype=torch.int64, - device=device, - requires_grad=False, - ) - self._storage[NUM_ACTIVATION_COVARIANCE_PROCESSED].add_(count) + raise NotImplementedError("Subclasses must implement the `get_flattened_activation` method.") @abstractmethod - def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: + def get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, int]]: """Returns the flattened output gradient tensor. Args: @@ -260,115 +253,30 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> Tuple[torch. The flattened output gradient tensor and the number of stacked gradients. The flattened gradient is a 2-dimensional matrix with dimension `gradient_num x gradient_dim`. 
""" - raise NotImplementedError("Subclasses must implement the `_get_flattened_gradient` method.") + raise NotImplementedError("Subclasses must implement the `get_flattened_gradient` method.") - def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None: - """Computes and updates the pseudo-gradient covariance matrix. + @abstractmethod + def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor: + """Returns the summed gradient tensor. Args: + input_activation (torch.Tensor): + The input tensor to the module, provided by the PyTorch's forward hook. output_gradient (torch.Tensor): - The gradient tensor with respect to the output of the module, provided by the - PyTorch's backward hook. - """ - output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) - flattened_gradient, count = self._get_flattened_gradient(output_gradient=output_gradient) - - if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None: - dimension = flattened_gradient.size(1) - self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] = torch.zeros( - size=(dimension, dimension), - dtype=flattened_gradient.dtype, - device=flattened_gradient.device, - requires_grad=False, - ) - self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient) - - # In most cases, `NUM_GRADIENT_COVARIANCE_PROCESSED` and `NUM_ACTIVATION_COVARIANCE_PROCESSED` are identical. - # However, they may differ when using gradient checkpointing or torch.compile(). - if self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None: - self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros( - size=(1,), - dtype=torch.int64, - device=count.device if isinstance(count, torch.Tensor) else None, - requires_grad=False, - ) - self._storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count) - - def _scale_output_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor: - """Scales the output gradient accordingly when gradient scaling is used.""" - if self._gradient_scale != 1.0: - output_gradient = output_gradient.detach() * self._gradient_scale - else: - output_gradient = output_gradient.detach() - return output_gradient - - def _register_covariance_hooks(self) -> None: - """Sets up hooks to compute activation and gradient covariance matrices.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - # Computes and updates activation covariance during forward pass. - self._update_activation_covariance_matrix(input_activation=inputs[0].detach().clone()) - # Registers backward hook to obtain gradient with respect to the output. - outputs.register_hook(backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - # Computes and updates pseudo-gradient covariance during backward pass. 
- output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - self._update_gradient_covariance_matrix(output_gradient=output_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - def _release_covariance_matrices(self) -> None: - """Clears both activation and gradient covariance matrices from memory.""" - for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - del self._storage[covariance_factor_name] - self._storage[covariance_factor_name] = None - - def _covariance_matrices_available(self) -> bool: - """Returns `True` if both activation and gradient matrices are available.""" - for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - if self._storage[covariance_factor_name] is None: - return False - return True - - @torch.no_grad() - def synchronize_covariance_matrices(self) -> None: - """Aggregates covariance matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and torch.cuda.is_available() and self._covariance_matrices_available(): - # Note that only the main process holds aggregated covariance matrices. - for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - self._storage[covariance_factor_name] = self._storage[covariance_factor_name].cuda() - dist.reduce( - tensor=self._storage[covariance_factor_name], - op=dist.ReduceOp.SUM, - dst=0, - ) - - ########################################## - # Methods for computing Lambda matrices. # - ########################################## - def _release_eigendecomposition_results(self) -> None: - """Clears all eigendecomposition results from memory.""" - for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: - del self._storage[eigen_factor_name] - self._storage[eigen_factor_name] = None + The gradient tensor with respect to the output of the module, provided by the PyTorch's backward hook. - def _eigendecomposition_results_available(self) -> bool: - """Returns `True` if eigendecomposition results are available.""" - for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: - if self._storage[eigen_factor_name] is None: - return False - return True + Returns: + torch.Tensor: + The aggregated gradient tensor. + """ + raise NotImplementedError("Subclasses must implement the `compute_summed_gradient` method.") @abstractmethod - def _compute_per_sample_gradient( + def compute_per_sample_gradient( self, input_activation: torch.Tensor, output_gradient: torch.Tensor ) -> torch.Tensor: - """Returns the flattened per-sample-gradient tensor. For a brief introduction to - per-sample-gradient, see https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html. + """Returns the flattened per-sample gradient tensor. For a brief introduction to + per-sample gradient, see https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html. Args: input_activation (torch.Tensor): @@ -378,787 +286,38 @@ def _compute_per_sample_gradient( Returns: torch.Tensor: - The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix - with dimension `batch_size x gradient_dim x activation_dim`. - """ - raise NotImplementedError("Subclasses must implement the `_compute_per_sample_gradient` method.") - - def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: - """Computes and updates the Lambda matrix using the provided per-sample-gradient. - - Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. 
- """ - per_sample_gradient = per_sample_gradient.to(self.factor_args.lambda_dtype) - batch_size = per_sample_gradient.size(0) - - if self._storage[LAMBDA_MATRIX_NAME] is None: - # Initializes the Lambda matrix if it does not exist. - self._storage[LAMBDA_MATRIX_NAME] = torch.zeros( - size=(per_sample_gradient.size(1), per_sample_gradient.size(2)), - dtype=per_sample_gradient.dtype, - device=per_sample_gradient.device, - requires_grad=False, - ) - - if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda: - if not self._eigendecomposition_results_available(): - error_msg = ( - f"The strategy {self.factor_args.strategy} requires Eigendecomposition " - f"results to be loaded for Lambda computations. However, Eigendecomposition " - f"results are not found." - ) - raise FactorsNotFoundError(error_msg) - # Moves activation and pseudo-gradient eigenvectors to appropriate devices. - self._storage[ACTIVATION_EIGENVECTORS_NAME] = self._storage[ACTIVATION_EIGENVECTORS_NAME].to( - dtype=self.factor_args.lambda_dtype, - device=per_sample_gradient.device, - ) - self._storage[GRADIENT_EIGENVECTORS_NAME] = self._storage[GRADIENT_EIGENVECTORS_NAME].to( - dtype=self.factor_args.lambda_dtype, - device=per_sample_gradient.device, - ) - - if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda: - if self.factor_args.use_iterative_lambda_aggregation: - # This batch-wise iterative update can be useful when the GPU memory is limited. - per_sample_gradient = torch.matmul( - per_sample_gradient, - self._storage[ACTIVATION_EIGENVECTORS_NAME], - ) - for i in range(batch_size): - sqrt_lambda = torch.matmul( - self._storage[GRADIENT_EIGENVECTORS_NAME].t(), - per_sample_gradient[i], - ) - self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_()) - else: - per_sample_gradient = torch.matmul( - self._storage[GRADIENT_EIGENVECTORS_NAME].t(), - torch.matmul(per_sample_gradient, self._storage[ACTIVATION_EIGENVECTORS_NAME]), - ) - self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0)) - else: - # Approximates the eigenbasis as identity. - self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0)) - - if self._storage[NUM_LAMBDA_PROCESSED] is None: - self._storage[NUM_LAMBDA_PROCESSED] = torch.zeros( - size=(1,), - dtype=torch.int64, - device=None, - requires_grad=False, - ) - self._storage[NUM_LAMBDA_PROCESSED].add_(batch_size) - - def _register_lambda_hooks(self) -> None: - """Sets up hooks to compute Lambda matrices.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.factor_args.per_sample_gradient_dtype) - if self.factor_args.offload_activations_to_cpu: - cached_activation = cached_activation.cpu() - - if self.factor_args.has_shared_parameters: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - if self._cached_activations is None: - raise RuntimeError( - f"Module '{self.name}' encountered multiple times in forward pass. " - "Set 'has_shared_parameters=True' to allow parameter sharing." 
- ) - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.factor_args.per_sample_gradient_dtype), - ).to(dtype=self.factor_args.lambda_dtype) - del self._cached_activations - self._cached_activations = None - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - self._update_lambda_matrix(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - cached_activation = self._cached_activations.pop() - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.factor_args.per_sample_gradient_dtype), - ) - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - def _clear_per_sample_gradient_cache(self) -> None: - """Clears all caches from per-sample-gradient computations.""" - del self._cached_per_sample_gradient - self._cached_per_sample_gradient = None - del self._cached_activations - self._cached_activations = None - - @torch.no_grad() - def finalize_lambda_matrix(self) -> None: - """Updates Lambda matrix using cached per-sample gradients.""" - self._update_lambda_matrix( - per_sample_gradient=self._cached_per_sample_gradient.to(dtype=self.factor_args.lambda_dtype) - ) - self._clear_per_sample_gradient_cache() - - def _release_lambda_matrix(self) -> None: - """Clears all Lambda matrices from memory.""" - for lambda_factor_name in LAMBDA_FACTOR_NAMES: - del self._storage[lambda_factor_name] - self._storage[lambda_factor_name] = None - self._clear_per_sample_gradient_cache() - - def _lambda_matrix_available(self) -> bool: - """Returns `True` if Lambda matrices are available.""" - for lambda_factor_name in LAMBDA_FACTOR_NAMES: - if self._storage[lambda_factor_name] is None: - return False - return True - - @torch.no_grad() - def synchronize_lambda_matrices(self) -> None: - """Aggregates Lambda matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and torch.cuda.is_available() and self._lambda_matrix_available(): - for lambda_factor_name in LAMBDA_FACTOR_NAMES: - self._storage[lambda_factor_name] = self._storage[lambda_factor_name].cuda() - dist.reduce( - tensor=self._storage[lambda_factor_name], - op=dist.ReduceOp.SUM, - dst=0, - ) - - ################################################## - # Methods for computing preconditioned gradient. # - ################################################## - def _compute_low_rank_preconditioned_gradient( - self, - preconditioned_gradient: torch.Tensor, - ) -> List[torch.Tensor]: - """Performs low-rank approximation of the preconditioned gradient with SVD. 
- - Args: - preconditioned_gradient (torch.Tensor): - The preconditioned per-sample-gradient matrix to be low-rank approximated. - - Returns: - List[torch.Tensor, torch.Tensor]: - Low-rank matrices that approximate the original preconditioned query gradient. - """ - rank = self.score_args.query_gradient_low_rank - if self.score_args.use_full_svd: - U, S, V = torch.linalg.svd( # pylint: disable=not-callable - preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - # Avoids holding the full memory of the original tensor before indexing. - V_k = V[:, :rank, :].contiguous().clone() - return [ - torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=self.score_args.score_dtype).contiguous().clone(), - V_k.to(dtype=self.score_args.score_dtype), - ] - U, S, V = torch.svd_lowrank( - preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype), - q=rank, - ) - return [ - torch.matmul(U, torch.diag_embed(S)).to(dtype=self.score_args.score_dtype).contiguous().clone(), - V.transpose(1, 2).contiguous().to(dtype=self.score_args.score_dtype), - ] - - def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> None: - """Computes the preconditioned per-sample-gradient. - - Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. + The per-sample gradient tensor. """ - preconditioned_gradient = FactorConfig.CONFIGS[self.factor_args.strategy].precondition_gradient( - gradient=per_sample_gradient.to(dtype=self.score_args.precondition_dtype), - storage=self._storage, - damping=self.score_args.damping_factor, - ) - del per_sample_gradient - - if ( - self.score_args.query_gradient_low_rank is not None - and min(preconditioned_gradient.size()[1:]) > self.score_args.query_gradient_low_rank - ): - # Applies low-rank approximation to the preconditioned gradient. 
- preconditioned_gradient = self._compute_low_rank_preconditioned_gradient( - preconditioned_gradient=preconditioned_gradient - ) - self._storage[PRECONDITIONED_GRADIENT_NAME] = preconditioned_gradient - else: - self._storage[PRECONDITIONED_GRADIENT_NAME] = preconditioned_gradient.to(dtype=self.score_args.score_dtype) - - def _register_precondition_gradient_hooks(self) -> None: - """Sets up hooks to compute preconditioned per-sample-gradient.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype) - if self.score_args.offload_activations_to_cpu: - cached_activation = cached_activation.cpu() - - if self.factor_args.has_shared_parameters: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ).to(dtype=self.score_args.precondition_dtype) - del self._cached_activations - self._cached_activations = None - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - self._compute_preconditioned_gradient(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - cached_activation = self._cached_activations.pop() - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ).to(dtype=self.score_args.precondition_dtype) - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - @torch.no_grad() - def finalize_preconditioned_gradient(self) -> None: - """Computes preconditioned per-sample-gradient using cached per-sample gradients.""" - self._compute_preconditioned_gradient(per_sample_gradient=self._cached_per_sample_gradient) - self._clear_per_sample_gradient_cache() - - @torch.no_grad() - def accumulate_preconditioned_gradient(self) -> None: - """Accumulates the preconditioned per-sample-gradients computed over different batches.""" - accumulated_gradient = self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] - gradient = self._storage[PRECONDITIONED_GRADIENT_NAME] - - if self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] is None: - 
self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = gradient - else: - if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = [ - torch.cat((accumulated_gradient[0], gradient[0]), dim=0).contiguous(), - torch.cat((accumulated_gradient[1], gradient[1]), dim=0).contiguous(), - ] - else: - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = torch.cat( - (accumulated_gradient, gradient), dim=0 - ).contiguous() - del self._storage[PRECONDITIONED_GRADIENT_NAME] - self._storage[PRECONDITIONED_GRADIENT_NAME] = None - - def release_preconditioned_gradient(self) -> None: - """Clears preconditioned per-sample-gradient from memory.""" - del self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None - del self._storage[PRECONDITIONED_GRADIENT_NAME] - self._storage[PRECONDITIONED_GRADIENT_NAME] = None - self._clear_per_sample_gradient_cache() + raise NotImplementedError("Subclasses must implement the `compute_per_sample_gradient` method.") - @torch.no_grad() - def truncate_preconditioned_gradient(self, keep_size: int) -> None: - """Truncates and keeps only the first `keep_size` dimension for preconditioned gradient.""" - if self._storage[PRECONDITIONED_GRADIENT_NAME] is not None: - if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): - assert len(self._storage[PRECONDITIONED_GRADIENT_NAME]) == 2 - self._storage[PRECONDITIONED_GRADIENT_NAME] = [ - self._storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size].clone(), - self._storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size].clone(), - ] - else: - self._storage[PRECONDITIONED_GRADIENT_NAME] = self._storage[PRECONDITIONED_GRADIENT_NAME][ - :keep_size - ].clone() - - @torch.no_grad() - def synchronize_preconditioned_gradient(self, num_processes: int) -> None: - """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" - if ( - dist.is_initialized() - and torch.cuda.is_available() - and self._storage[PRECONDITIONED_GRADIENT_NAME] is not None - ): - if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): - for i in range(len(self._storage[PRECONDITIONED_GRADIENT_NAME])): - size = self._storage[PRECONDITIONED_GRADIENT_NAME][i].size() - stacked_matrix = torch.empty( - size=(num_processes, size[0], size[1], size[2]), - dtype=self._storage[PRECONDITIONED_GRADIENT_NAME][i].dtype, - device=self._storage[PRECONDITIONED_GRADIENT_NAME][i].device, - ) - torch.distributed.all_gather_into_tensor( - output_tensor=stacked_matrix, - input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(), - ) - self._storage[PRECONDITIONED_GRADIENT_NAME][i] = ( - stacked_matrix.transpose(0, 1) - .reshape(num_processes * size[0], size[1], size[2]) - .contiguous() - .clone() - ) - else: - size = self._storage[PRECONDITIONED_GRADIENT_NAME].size() - stacked_preconditioned_gradient = torch.empty( - size=(num_processes, size[0], size[1], size[2]), - dtype=self._storage[PRECONDITIONED_GRADIENT_NAME].dtype, - device=self._storage[PRECONDITIONED_GRADIENT_NAME].device, - ) - torch.distributed.all_gather_into_tensor( - output_tensor=stacked_preconditioned_gradient, - input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME].contiguous(), - ) - self._storage[PRECONDITIONED_GRADIENT_NAME] = ( - stacked_preconditioned_gradient.transpose(0, 1) - .reshape(num_processes * size[0], size[1], size[2]) - .contiguous() - .clone() - ) - - #################################################### - # 
Methods for computing pairwise influence scores. # - #################################################### @abstractmethod - def _compute_pairwise_score(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> None: + def compute_pairwise_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: """Computes pairwise influence scores. Args: + preconditioned_gradient (torch.Tensor): + The preconditioned gradient. input_activation (torch.Tensor): The input tensor to the module, provided by the PyTorch's forward hook. output_gradient (torch.Tensor): The gradient tensor with respect to the output of the module, provided by the PyTorch's backward hook. """ - raise NotImplementedError("Subclasses must implement the `_compute_pairwise_score` method.") - - def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tensor) -> None: - """Computes pairwise influence scores using per-sample-gradient. - - Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. - """ - per_sample_gradient = per_sample_gradient.to(dtype=self.score_args.score_dtype) - if isinstance(self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], list): - left_mat, right_mat = self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] - if self._opt_einsum_expression is None: - self._opt_einsum_expression = contract_expression( - "qki,toi,qok->qt", - right_mat.shape, - per_sample_gradient.shape, - left_mat.shape, - optimize=DynamicProgramming( - search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops" - ), - ) - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = self._opt_einsum_expression( - right_mat, per_sample_gradient, left_mat - ) - else: - self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract( - "qio,tio->qt", - self._storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], - per_sample_gradient, - ) - - def _register_pairwise_score_hooks(self) -> None: - """Sets up hooks to compute pairwise influence scores.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.score_dtype) - if self.score_args.offload_activations_to_cpu: - cached_activation = cached_activation.cpu() - - if self.factor_args.has_shared_parameters: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - if self.per_sample_gradient_process_fnc is None: - self._compute_pairwise_score( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.score_dtype), - ) - del self._cached_activations - self._cached_activations = None - else: - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.score_dtype), - ) - del self._cached_activations - self._cached_activations = None - per_sample_gradient = self.per_sample_gradient_process_fnc( - 
module_name=self.name, gradient=per_sample_gradient - ) - self._compute_pairwise_score_with_gradient(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - cached_activation = self._cached_activations.pop() - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.score_dtype), - ) - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - @torch.no_grad() - def finalize_pairwise_score(self) -> None: - """Computes pairwise influence scores using cached per-sample-gradient.""" - self._compute_pairwise_score_with_gradient(per_sample_gradient=self._cached_per_sample_gradient) - self._clear_per_sample_gradient_cache() - - #################################################### - # Methods for aggregating gradients. # - #################################################### - def _register_gradient_aggregation_hooks(self) -> None: - """Sets up hooks to aggregate gradients.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype) - if self.score_args.offload_activations_to_cpu: - cached_activation = cached_activation.cpu() - - if self.factor_args.has_shared_parameters: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ) - del self._cached_activations - self._cached_activations = None - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - if self.aggregated_gradient is None: - self.aggregated_gradient = torch.zeros( - size=(1, per_sample_gradient.size(1), per_sample_gradient.size(2)), - dtype=per_sample_gradient.dtype, - device=per_sample_gradient.device, - requires_grad=False, - ) - self.aggregated_gradient.add_(per_sample_gradient.sum(dim=0, keepdim=True)) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - cached_activation = self._cached_activations.pop() - per_sample_gradient = self._compute_per_sample_gradient( - 
input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ) - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - summed_gradient = per_sample_gradient.sum(dim=0, keepdim=True) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = torch.zeros_like(summed_gradient, requires_grad=False) - self._cached_per_sample_gradient.add_(summed_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - @torch.no_grad() - def finalize_gradient_aggregation(self) -> None: - """Computes aggregated gradients using cached gradients.""" - if self.aggregated_gradient is None: - self.aggregated_gradient = torch.zeros_like(self._cached_per_sample_gradient, requires_grad=False) - self.aggregated_gradient.add_(self._cached_per_sample_gradient) - self._clear_per_sample_gradient_cache() - - @torch.no_grad() - def synchronize_aggregated_gradient(self) -> None: - """Aggregates aggregated gradient across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and torch.cuda.is_available(): - if self.aggregated_gradient is None: - self.aggregated_gradient = torch.zeros( - size=(1,), - dtype=self.score_args.per_sample_gradient_dtype, - device="cuda", - requires_grad=False, - ) - dist.all_reduce( - tensor=self.aggregated_gradient, - op=dist.ReduceOp.SUM, - ) - - def release_aggregated_gradient(self) -> None: - """Clears aggregated gradient from memory.""" - del self.aggregated_gradient - self.aggregated_gradient = None - self._clear_per_sample_gradient_cache() - - @torch.no_grad() - def compute_preconditioned_gradient_from_aggregation(self) -> None: - """Computes preconditioned gradient using cached gradients.""" - self._compute_preconditioned_gradient( - per_sample_gradient=self.aggregated_gradient.to(dtype=self.score_args.precondition_dtype) - ) - - @torch.no_grad() - def compute_pairwise_scores_from_aggregation(self) -> None: - """Computes pairwise influence scores using cached gradients.""" - self._compute_pairwise_score_with_gradient( - per_sample_gradient=self.aggregated_gradient.to(dtype=self.score_args.score_dtype) - ) - - #################################################### - # Methods for computing self-influence scores. # - #################################################### - def _compute_self_score(self, per_sample_gradient: torch.Tensor) -> None: - """Computes self-influence scores using per-sample-gradients. - - Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. 
- """ - if not self._storage_at_device: - self._move_storage_to_device( - target_device=per_sample_gradient.device, target_dtype=self.score_args.precondition_dtype - ) - self._storage_at_device = True - preconditioned_gradient = ( - FactorConfig.CONFIGS[self.factor_args.strategy] - .precondition_gradient( - gradient=per_sample_gradient.to(dtype=self.score_args.precondition_dtype), - storage=self._storage, - damping=self.score_args.damping_factor, - ) - .to(dtype=self.score_args.score_dtype) - ) - preconditioned_gradient.mul_(per_sample_gradient.to(dtype=self.score_args.score_dtype)) - self._storage[SELF_SCORE_VECTOR_NAME] = preconditioned_gradient.sum(dim=(1, 2)) - - def _register_self_score_hooks(self) -> None: - """Installs forward and backward hooks for computation of self-influence scores.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype) - if self.score_args.offload_activations_to_cpu: - cached_activation = cached_activation.cpu() - - if self.factor_args.has_shared_parameters: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ).to(dtype=self.score_args.precondition_dtype) - del self._cached_activations - self._cached_activations = None - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - self._compute_self_score(per_sample_gradient=per_sample_gradient) + raise NotImplementedError("Subclasses must implement the `compute_pairwise_score` method.") - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - cached_activation = self._cached_activations.pop() - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ) - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) - self._cached_per_sample_gradient.add_(per_sample_gradient) - - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) - - def finalize_self_score(self) -> None: - """Computes the self-influence scores using the cached per-sample-gradient.""" - self._compute_self_score(per_sample_gradient=self._cached_per_sample_gradient) - self._clear_per_sample_gradient_cache() - - def _compute_self_measurement_score(self, per_sample_gradient: torch.Tensor) -> 
None: - """Computes the self-influence scores with measurement. + @abstractmethod + def compute_self_measurement_score( + self, preconditioned_gradient: torch.Tensor, input_activation: torch.Tensor, output_gradient: torch.Tensor + ) -> torch.Tensor: + """Computes self-influence scores with measurement. Args: - per_sample_gradient (torch.Tensor): - The per-sample-gradient tensor for the given batch. + preconditioned_gradient (torch.Tensor): + The preconditioned gradient. + input_activation (torch.Tensor): + The input tensor to the module, provided by the PyTorch's forward hook. + output_gradient (torch.Tensor): + The gradient tensor with respect to the output of the module, provided by the PyTorch's backward hook. """ - per_sample_gradient = per_sample_gradient.to(dtype=self.score_args.score_dtype) - if not self._storage_at_device: - self._move_storage_to_device( - target_device=per_sample_gradient.device, target_dtype=self.score_args.precondition_dtype - ) - self._storage_at_device = True - self._storage[SELF_SCORE_VECTOR_NAME] = per_sample_gradient.mul_( - self._storage[PRECONDITIONED_GRADIENT_NAME] - ).sum(dim=(1, 2)) - del self._storage[PRECONDITIONED_GRADIENT_NAME] - self._storage[PRECONDITIONED_GRADIENT_NAME] = None - - def _register_self_measurement_score_hooks(self) -> None: - """Installs forward and backward hooks for computation of self-influence scores.""" - - @torch.no_grad() - def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: - del module - cached_activation = inputs[0].detach().clone().to(dtype=self.score_args.per_sample_gradient_dtype) - if self.score_args.offload_activations_to_cpu: - cached_activation = cached_activation.cpu() - - if self.factor_args.has_shared_parameters: - if self._cached_activations is None: - self._cached_activations = [] - self._cached_activations.append(cached_activation) - else: - self._cached_activations = cached_activation - - outputs.register_hook(shared_backward_hook if self.factor_args.has_shared_parameters else backward_hook) - - @torch.no_grad() - def backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=self._cached_activations.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ).to(dtype=self.score_args.score_dtype) - del self._cached_activations - self._cached_activations = None - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - self._compute_self_measurement_score(per_sample_gradient=per_sample_gradient) - - @torch.no_grad() - def shared_backward_hook(output_gradient: torch.Tensor) -> None: - output_gradient = self._scale_output_gradient(output_gradient=output_gradient) - cached_activation = self._cached_activations.pop() - per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to(device=output_gradient.device), - output_gradient=output_gradient.to(dtype=self.score_args.per_sample_gradient_dtype), - ) - if self.per_sample_gradient_process_fnc is not None: - per_sample_gradient = self.per_sample_gradient_process_fnc( - module_name=self.name, gradient=per_sample_gradient - ) - if self._cached_per_sample_gradient is None: - self._cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, 
requires_grad=False)
- self._cached_per_sample_gradient.add_(per_sample_gradient)
-
- self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
-
- @torch.no_grad()
- def finalize_self_measurement_score(self) -> None:
- """Computes the self-influence scores with measurement using the cached per-sample-gradient."""
- self._compute_self_measurement_score(
- per_sample_gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.score_dtype)
- )
- self._clear_per_sample_gradient_cache()
-
- def _move_storage_to_device(self, target_device: torch.device, target_dtype: torch.dtype) -> None:
- """Moves stored factors into the target device."""
- for name, factor in self._storage.items():
- if factor is not None:
- if isinstance(factor, list):
- for i in range(len(self._storage[name])):
- self._storage[name][i] = factor[i].to(
- device=target_device,
- dtype=target_dtype,
- )
- else:
- self._storage[name] = factor.to(device=target_device, dtype=target_dtype)
-
- def release_scores(self) -> None:
- """Clears the influence scores from memory."""
- del self._storage[PAIRWISE_SCORE_MATRIX_NAME]
- self._storage[PAIRWISE_SCORE_MATRIX_NAME] = None
- del self._storage[SELF_SCORE_VECTOR_NAME]
- self._storage[SELF_SCORE_VECTOR_NAME] = None
- self._clear_per_sample_gradient_cache()
+ raise NotImplementedError("Subclasses must implement the `compute_self_measurement_score` method.")
diff --git a/kronfluence/module/tracker/__init__.py b/kronfluence/module/tracker/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/kronfluence/module/tracker/base.py b/kronfluence/module/tracker/base.py
new file mode 100644
index 0000000..24181be
--- /dev/null
+++ b/kronfluence/module/tracker/base.py
@@ -0,0 +1,99 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.utils.hooks import RemovableHandle
+
+
+class BaseTracker:
+ """Base class for tracking module activations, gradients, and scores."""
+
+ def __init__(self, module: nn.Module) -> None:
+ """Initializes an instance of the `BaseTracker` class.
+
+ Args:
+ module (TrackedModule):
+ The `TrackedModule` that wraps the original module.
+ """
+ self.module = module
+ self.registered_hooks: List[RemovableHandle] = []
+ self.cached_activations: Optional[Union[List[torch.Tensor], torch.Tensor]] = None
+ self.cached_per_sample_gradient: Optional[torch.Tensor] = None
+
+ def release_hooks(self) -> None:
+ """Removes all registered hooks."""
+ while self.registered_hooks:
+ handle = self.registered_hooks.pop()
+ handle.remove()
+ self.registered_hooks = []
+
+ def clear_all_cache(self) -> None:
+ """Clears all cached data from memory."""
+ del self.cached_activations, self.cached_per_sample_gradient
+ self.cached_activations, self.cached_per_sample_gradient = None, None
+
+ def _scale_output_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
+ """Scales the output gradient and converts it to the target dtype.
+
+ Args:
+ output_gradient (torch.Tensor):
+ The output gradient to scale.
+ target_dtype (torch.dtype):
+ The desired dtype for the output.
+
+ Returns:
+ torch.Tensor:
+ The scaled gradient in the target dtype.
+ """ + original_dtype = output_gradient.dtype + output_gradient = output_gradient.detach().to(dtype=target_dtype) + if self.module.gradient_scale != 1.0: + if original_dtype != target_dtype: + output_gradient.mul_(self.module.gradient_scale) + else: + output_gradient = output_gradient * self.module.gradient_scale + return output_gradient + + def _raise_cache_not_found_exception(self) -> None: + """Raises an exception when cached activations are not found.""" + raise RuntimeError( + f"Module '{self.module.name}' has no cached activations. This can occur if:\n" + f"1. The module was not used during the forward pass, or\n" + f"2. The module was encountered multiple times in the forward pass.\n" + f"For case 2, set 'has_shared_parameters=True' to enable parameter sharing." + ) + + def register_hooks(self) -> None: + """Registers hooks for the module.""" + + def finalize_iteration(self) -> None: + """Finalizes statistics for the current iteration.""" + + def exist(self) -> bool: + """Checks if the desired statistics are available.""" + return False + + def synchronize(self, num_processes: int) -> None: + """Synchronizes statistics across multiple processes. + + Args: + num_processes (int): + The number of processes to synchronize across. + """ + + def truncate(self, keep_size: int) -> None: + """Truncates stored statistics to a specified size. + + Args: + keep_size (int): + The number of dimension to keep. + """ + + def accumulate_iterations(self) -> None: + """Accumulates (or prepares to accumulate) statistics across multiple iterations.""" + + def finalize_all_iterations(self) -> None: + """Finalizes statistics after all iterations.""" + + def release_memory(self) -> None: + """Releases any memory held by the tracker.""" diff --git a/kronfluence/module/tracker/factor.py b/kronfluence/module/tracker/factor.py new file mode 100644 index 0000000..ccb5c09 --- /dev/null +++ b/kronfluence/module/tracker/factor.py @@ -0,0 +1,293 @@ +from typing import Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn + +from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + ACTIVATION_COVARIANCE_MATRIX_NAME, + ACTIVATION_EIGENVECTORS_NAME, + COVARIANCE_FACTOR_NAMES, + EIGENDECOMPOSITION_FACTOR_NAMES, + GRADIENT_COVARIANCE_MATRIX_NAME, + GRADIENT_EIGENVECTORS_NAME, + LAMBDA_FACTOR_NAMES, + LAMBDA_MATRIX_NAME, + NUM_ACTIVATION_COVARIANCE_PROCESSED, + NUM_GRADIENT_COVARIANCE_PROCESSED, + NUM_LAMBDA_PROCESSED, +) +from kronfluence.utils.exceptions import FactorsNotFoundError + + +class CovarianceTracker(BaseTracker): + """Tracks and computes activation and gradient covariance matrices for a given module.""" + + def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -> None: + """Computes and updates the activation covariance matrix. + + Args: + input_activation (torch.Tensor): + The input tensor to the module, provided by PyTorch's forward hook. 
+ """ + flattened_activation, count = self.module.get_flattened_activation(input_activation=input_activation) + + if self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] is None: + self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED] = torch.zeros( + size=(1,), + dtype=torch.int64, + device=count.device if isinstance(count, torch.Tensor) else None, + requires_grad=False, + ) + dimension = flattened_activation.size(1) + self.module.storage[ACTIVATION_COVARIANCE_MATRIX_NAME] = torch.zeros( + size=(dimension, dimension), + dtype=flattened_activation.dtype, + device=flattened_activation.device, + requires_grad=False, + ) + self.module.storage[NUM_ACTIVATION_COVARIANCE_PROCESSED].add_(count) + self.module.storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation) + + def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None: + """Computes and updates the pseudo-gradient covariance matrix. + + Args: + output_gradient (torch.Tensor): + The gradient tensor with respect to the output of the module, provided by PyTorch's backward hook. + """ + flattened_gradient, count = self.module.get_flattened_gradient(output_gradient=output_gradient) + + if self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED] is None: + # In most cases, `NUM_GRADIENT_COVARIANCE_PROCESSED` and `NUM_ACTIVATION_COVARIANCE_PROCESSED` are + # identical. However, they may differ when using gradient checkpointing or torch.compile(). + self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED] = torch.zeros( + size=(1,), + dtype=torch.int64, + device=count.device if isinstance(count, torch.Tensor) else None, + requires_grad=False, + ) + dimension = flattened_gradient.size(1) + self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME] = torch.zeros( + size=(dimension, dimension), + dtype=flattened_gradient.dtype, + device=flattened_gradient.device, + requires_grad=False, + ) + self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count) + self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient) + + def register_hooks(self) -> None: + """Sets up hooks to compute activation and gradient covariance matrices.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + # Computes and updates activation covariance during forward pass. + input_activation = inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype) + self._update_activation_covariance_matrix(input_activation=input_activation) + outputs.register_hook(backward_hook) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + # Computes and updates pseudo-gradient covariance during backward pass. 
+ output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.factor_args.gradient_covariance_dtype + ) + self._update_gradient_covariance_matrix(output_gradient=output_gradient) + + self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + + def exist(self) -> bool: + """Checks if both activation and gradient covariance matrices are available.""" + for covariance_factor_name in COVARIANCE_FACTOR_NAMES: + if self.module.storage[covariance_factor_name] is None: + return False + return True + + def synchronize(self, num_processes: int) -> None: + """Aggregates covariance matrices across multiple devices or nodes in a distributed setting.""" + del num_processes + if dist.is_initialized() and torch.cuda.is_available() and self.exist(): + for covariance_factor_name in COVARIANCE_FACTOR_NAMES: + self.module.storage[covariance_factor_name] = self.module.storage[covariance_factor_name].cuda() + dist.reduce( + tensor=self.module.storage[covariance_factor_name], + op=dist.ReduceOp.SUM, + dst=0, + ) + + def release_memory(self) -> None: + """Clears all covariance matrices from memory.""" + for covariance_factor_name in COVARIANCE_FACTOR_NAMES: + del self.module.storage[covariance_factor_name] + self.module.storage[covariance_factor_name] = None + + +class LambdaTracker(BaseTracker): + """Tracks and computes Lambda matrices for a given module.""" + + def _eigendecomposition_results_exist(self) -> bool: + """Checks if eigendecomposition results are available.""" + for eigen_factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: + if self.module.storage[eigen_factor_name] is None: + return False + return True + + def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: + """Computes and updates the Lambda matrix using provided per-sample gradient. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample gradient tensor for the given batch. + """ + batch_size = per_sample_gradient.size(0) + + if self.module.storage[NUM_LAMBDA_PROCESSED] is None: + self.module.storage[NUM_LAMBDA_PROCESSED] = torch.zeros( + size=(1,), + dtype=torch.int64, + device=None, + requires_grad=False, + ) + self.module.storage[LAMBDA_MATRIX_NAME] = torch.zeros( + size=(per_sample_gradient.size(1), per_sample_gradient.size(2)), + dtype=per_sample_gradient.dtype, + device=per_sample_gradient.device, + requires_grad=False, + ) + + if FactorConfig.CONFIGS[self.module.factor_args.strategy].requires_eigendecomposition_for_lambda: + if not self._eigendecomposition_results_exist(): + raise FactorsNotFoundError( + f"The strategy {self.module.factor_args.strategy} requires eigendecomposition " + f"results for Lambda computations, but they are not found." + ) + + # Move activation and pseudo-gradient eigenvectors to appropriate devices. + self.module.storage[ACTIVATION_EIGENVECTORS_NAME] = self.module.storage[ + ACTIVATION_EIGENVECTORS_NAME + ].to( + dtype=per_sample_gradient.dtype, + device=per_sample_gradient.device, + ) + self.module.storage[GRADIENT_EIGENVECTORS_NAME] = self.module.storage[GRADIENT_EIGENVECTORS_NAME].to( + dtype=per_sample_gradient.dtype, + device=per_sample_gradient.device, + ) + + self.module.storage[NUM_LAMBDA_PROCESSED].add_(batch_size) + if FactorConfig.CONFIGS[self.module.factor_args.strategy].requires_eigendecomposition_for_lambda: + if self.module.factor_args.use_iterative_lambda_aggregation: + # This batch-wise iterative update can be useful when the GPU memory is limited. 
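+ # Accumulates (Q_g^T g_i Q_a).square() one sample at a time, where Q_a and Q_g are the activation
+ # and pseudo-gradient eigenvectors, instead of materializing the full projected batch at once.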
+ per_sample_gradient = torch.matmul(
+ per_sample_gradient,
+ self.module.storage[ACTIVATION_EIGENVECTORS_NAME],
+ )
+ for i in range(batch_size):
+ sqrt_lambda = torch.matmul(
+ self.module.storage[GRADIENT_EIGENVECTORS_NAME].t(),
+ per_sample_gradient[i],
+ )
+ self.module.storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_())
+ else:
+ per_sample_gradient = torch.matmul(
+ self.module.storage[GRADIENT_EIGENVECTORS_NAME].t(),
+ torch.matmul(per_sample_gradient, self.module.storage[ACTIVATION_EIGENVECTORS_NAME]),
+ )
+ self.module.storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0))
+ else:
+ # Approximate the eigenbasis as identity.
+ self.module.storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0))
+
+ def register_hooks(self) -> None:
+ """Sets up hooks to compute Lambda matrices."""
+
+ @torch.no_grad()
+ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
+ del module
+ cached_activation = inputs[0].detach()
+ device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device
+ cached_activation = cached_activation.to(
+ device=device,
+ dtype=self.module.factor_args.per_sample_gradient_dtype,
+ copy=True,
+ )
+
+ if self.module.factor_args.has_shared_parameters:
+ if self.cached_activations is None:
+ self.cached_activations = []
+ self.cached_activations.append(cached_activation)
+ else:
+ self.cached_activations = cached_activation
+
+ outputs.register_hook(
+ shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook
+ )
+
+ @torch.no_grad()
+ def backward_hook(output_gradient: torch.Tensor) -> None:
+ if self.cached_activations is None:
+ self._raise_cache_not_found_exception()
+
+ output_gradient = self._scale_output_gradient(
+ output_gradient=output_gradient, target_dtype=self.module.factor_args.per_sample_gradient_dtype
+ )
+ per_sample_gradient = self.module.compute_per_sample_gradient(
+ input_activation=self.cached_activations.to(device=output_gradient.device),
+ output_gradient=output_gradient,
+ ).to(dtype=self.module.factor_args.lambda_dtype)
+ self.clear_all_cache()
+ self._update_lambda_matrix(per_sample_gradient=per_sample_gradient)
+
+ @torch.no_grad()
+ def shared_backward_hook(output_gradient: torch.Tensor) -> None:
+ output_gradient = self._scale_output_gradient(
+ output_gradient=output_gradient, target_dtype=self.module.factor_args.per_sample_gradient_dtype
+ )
+ cached_activation = self.cached_activations.pop()
+ per_sample_gradient = self.module.compute_per_sample_gradient(
+ input_activation=cached_activation.to(device=output_gradient.device),
+ output_gradient=output_gradient,
+ )
+ if self.cached_per_sample_gradient is None:
+ self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False)
+ self.cached_per_sample_gradient.add_(per_sample_gradient)
+
+ self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
+
+ @torch.no_grad()
+ def finalize_iteration(self) -> None:
+ """Updates Lambda matrix using cached per-sample gradients."""
+ if self.module.factor_args.has_shared_parameters:
+ self.cached_per_sample_gradient = self.cached_per_sample_gradient.to(dtype=self.module.factor_args.lambda_dtype)
+ self._update_lambda_matrix(per_sample_gradient=self.cached_per_sample_gradient)
+ self.clear_all_cache()
+
+ def exist(self) -> bool:
+ """Checks if Lambda matrices are available."""
+ for lambda_factor_name in LAMBDA_FACTOR_NAMES:
+ if
self.module.storage[lambda_factor_name] is None: + return False + return True + + def synchronize(self, num_processes: int) -> None: + """Aggregates Lambda matrices across multiple devices or nodes in a distributed setting.""" + del num_processes + if dist.is_initialized() and torch.cuda.is_available() and self.exist(): + for lambda_factor_name in LAMBDA_FACTOR_NAMES: + self.module.storage[lambda_factor_name] = self.module.storage[lambda_factor_name].cuda() + dist.reduce( + tensor=self.module.storage[lambda_factor_name], + op=dist.ReduceOp.SUM, + dst=0, + ) + + def release_memory(self) -> None: + """Clears Lambda matrices from memory.""" + self.clear_all_cache() + for lambda_factor_name in LAMBDA_FACTOR_NAMES: + del self.module.storage[lambda_factor_name] + self.module.storage[lambda_factor_name] = None diff --git a/kronfluence/module/tracker/gradient.py b/kronfluence/module/tracker/gradient.py new file mode 100644 index 0000000..e88f835 --- /dev/null +++ b/kronfluence/module/tracker/gradient.py @@ -0,0 +1,126 @@ +from typing import List, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn + +from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, + AGGREGATED_GRADIENT_NAME, + PRECONDITIONED_GRADIENT_NAME, +) + + +class GradientTracker(BaseTracker): + """Tracks and computes summed gradient for a given module.""" + + def register_hooks(self) -> None: + """Sets up hooks to compute and keep track of summed gradient.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + if self.module.per_sample_gradient_process_fnc is None: + summed_gradient = self.module.compute_summed_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + else: + summed_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ).sum(dim=0, keepdim=True) + self.clear_all_cache() + + if self.module.storage[AGGREGATED_GRADIENT_NAME] is None: + self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(summed_gradient, requires_grad=False) + self.module.storage[AGGREGATED_GRADIENT_NAME].add_(summed_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + output_gradient = self._scale_output_gradient( + 
output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype
+ )
+ cached_activation = self.cached_activations.pop()
+ if self.module.per_sample_gradient_process_fnc is None:
+ summed_gradient = self.module.compute_summed_gradient(
+ input_activation=cached_activation.to(device=output_gradient.device),
+ output_gradient=output_gradient,
+ )
+ else:
+ summed_gradient = self.module.compute_per_sample_gradient(
+ input_activation=cached_activation.to(device=output_gradient.device),
+ output_gradient=output_gradient,
+ ).sum(dim=0, keepdim=True)
+
+ if self.cached_per_sample_gradient is None:
+ self.cached_per_sample_gradient = torch.zeros_like(summed_gradient, requires_grad=False)
+ self.cached_per_sample_gradient.add_(summed_gradient)
+
+ self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
+
+ def exist(self) -> bool:
+ """Checks if the aggregated gradient is available."""
+ return self.module.storage[AGGREGATED_GRADIENT_NAME] is not None
+
+ @torch.no_grad()
+ def finalize_iteration(self) -> None:
+ """Finalizes the aggregated gradient using cached summed gradients."""
+ if not self.module.factor_args.has_shared_parameters:
+ return
+ if self.module.storage[AGGREGATED_GRADIENT_NAME] is None:
+ self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(
+ self.cached_per_sample_gradient, requires_grad=False
+ )
+ self.module.storage[AGGREGATED_GRADIENT_NAME].add_(self.cached_per_sample_gradient)
+ self.clear_all_cache()
+
+ def release_memory(self) -> None:
+ """Clears summed gradients from memory."""
+ del self.module.storage[AGGREGATED_GRADIENT_NAME]
+ self.module.storage[AGGREGATED_GRADIENT_NAME] = None
+ self.clear_all_cache()
+
+ def synchronize(self, num_processes: int = 1) -> None:
+ """Aggregates summed gradient across multiple devices or nodes in a distributed setting."""
+ del num_processes
+ if dist.is_initialized() and torch.cuda.is_available():
+ if self.module.storage[AGGREGATED_GRADIENT_NAME] is None:
+ self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros(
+ size=(1,),
+ dtype=self.module.score_args.per_sample_gradient_dtype,
+ device="cuda",
+ requires_grad=False,
+ )
+ dist.all_reduce(
+ tensor=self.module.storage[AGGREGATED_GRADIENT_NAME],
+ op=dist.ReduceOp.SUM,
+ )
diff --git a/kronfluence/module/tracker/pairwise_score.py b/kronfluence/module/tracker/pairwise_score.py
new file mode 100644
index 0000000..029afc4
--- /dev/null
+++ b/kronfluence/module/tracker/pairwise_score.py
@@ -0,0 +1,149 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from opt_einsum import DynamicProgramming, contract_expression
+
+from kronfluence.module.tracker.base import BaseTracker
+from kronfluence.utils.constants import (
+ ACCUMULATED_PRECONDITIONED_GRADIENT_NAME,
+ AGGREGATED_GRADIENT_NAME,
+ PAIRWISE_SCORE_MATRIX_NAME,
+ PRECONDITIONED_GRADIENT_NAME,
+)
+
+
+class PairwiseScoreTracker(BaseTracker):
+ """Computes pairwise influence scores for a given module."""
+
+ def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tensor) -> None:
+ """Computes pairwise influence scores using per-sample-gradient.
+
+ Args:
+ per_sample_gradient (torch.Tensor):
+ The per-sample-gradient tensor for the given batch.
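+ Expected shape is `batch_size x gradient_dim x activation_dim`, matching the output of
+ `compute_per_sample_gradient` (or `1 x gradient_dim x activation_dim` for the aggregated case).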
+ """ + if isinstance(self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], list): + left_mat, right_mat = self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] + if self.module.einsum_expression is None: + self.module.einsum_expression = contract_expression( + "qki,toi,qok->qt", + right_mat.shape, + per_sample_gradient.shape, + left_mat.shape, + optimize=DynamicProgramming( + search_outer=True, minimize="size" if self.module.score_args.einsum_minimize_size else "flops" + ), + ) + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = self.module.einsum_expression( + right_mat, per_sample_gradient, left_mat + ) + else: + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = torch.einsum( + "qio,tio->qt", + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], + per_sample_gradient, + ) + + def register_hooks(self) -> None: + """Sets up hooks to compute pairwise influence scores.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype + ) + if self.module.per_sample_gradient_process_fnc is None: + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = self.module.compute_pairwise_score( + preconditioned_gradient=self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + self.clear_all_cache() + else: + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + self.clear_all_cache() + self._compute_pairwise_score_with_gradient(per_sample_gradient=per_sample_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype + ) + cached_activation = self.cached_activations.pop() + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + if self.cached_per_sample_gradient is None: + self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + self.cached_per_sample_gradient.add_(per_sample_gradient) + + self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + + @torch.no_grad() + def finalize_iteration(self) -> None: + """Computes pairwise influence scores using cached per-sample gradients.""" + if 
self.module.factor_args.has_shared_parameters: + self._compute_pairwise_score_with_gradient(per_sample_gradient=self.cached_per_sample_gradient) + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if pairwise score is available.""" + return self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] is not None + + def accumulate_iterations(self) -> None: + """Removes pairwise scores from memory after a single iteration.""" + self.release_memory() + + def finalize_all_iterations(self) -> None: + """Removes cached preconditioned gradient from memory. Additionally, if aggregated gradients are available, + computes the pairwise score using them.""" + if self.module.storage[AGGREGATED_GRADIENT_NAME] is not None: + self.module.storage[AGGREGATED_GRADIENT_NAME] = self.module.storage[AGGREGATED_GRADIENT_NAME].to( + dtype=self.module.score_args.precondition_dtype + ) + self._compute_pairwise_score_with_gradient( + per_sample_gradient=self.module.storage[AGGREGATED_GRADIENT_NAME] + ) + del ( + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], + self.module.storage[PRECONDITIONED_GRADIENT_NAME], + ) + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + self.clear_all_cache() + + def release_memory(self) -> None: + """Releases pairwise scores from memory.""" + self.clear_all_cache() + del self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] + self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] = None diff --git a/kronfluence/module/tracker/precondition.py b/kronfluence/module/tracker/precondition.py new file mode 100644 index 0000000..f9ccd60 --- /dev/null +++ b/kronfluence/module/tracker/precondition.py @@ -0,0 +1,248 @@ +from typing import List, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn + +from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, + AGGREGATED_GRADIENT_NAME, + PRECONDITIONED_GRADIENT_NAME, +) + + +class PreconditionTracker(BaseTracker): + """Computes preconditioned gradient for a given module.""" + + def _compute_low_rank_preconditioned_gradient( + self, + preconditioned_gradient: torch.Tensor, + target_dtype: torch.dtype, + ) -> List[torch.Tensor]: + """Performs low-rank approximation of the preconditioned gradient. + + Args: + preconditioned_gradient (torch.Tensor): + The preconditioned per-sample gradient tensor to be low-rank approximated. + target_dtype (torch.dtype): + The desired dtype for the output. + + Returns: + List[torch.Tensor, torch.Tensor]: + Low-rank matrices approximating the original preconditioned gradient. + """ + rank = self.module.score_args.query_gradient_low_rank + if self.module.score_args.use_full_svd: + U, S, V = torch.linalg.svd( # pylint: disable=not-callable + preconditioned_gradient, + full_matrices=False, + ) + U_k = U[:, :, :rank] + S_k = S[:, :rank] + # Avoid holding the full memory of the original tensor before indexing. 
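As a rough illustration of the low-rank approximation step added here: a batched SVD keeps only the top `rank` singular triples of each preconditioned query gradient, so pairwise scores can later be formed without materializing the full gradients. A self-contained sketch with hypothetical shapes (the dtype handling from `score_args` is omitted):

import torch

B, o, i, r = 4, 12, 20, 3
grads = torch.randn(B, o, i)                               # preconditioned query gradients

U, S, Vh = torch.linalg.svd(grads, full_matrices=False)    # batched SVD
left = U[:, :, :r] @ torch.diag_embed(S[:, :r])            # (B, o, r)
right = Vh[:, :r, :]                                       # (B, r, i)

approx = left @ right
print(torch.linalg.matrix_norm(grads - approx))            # per-sample approximation error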
+ V_k = V[:, :rank, :].to(dtype=target_dtype, copy=True) + left_mat = torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=target_dtype) + return [left_mat, V_k] + + U, S, V = torch.svd_lowrank(preconditioned_gradient, q=rank) + left_mat = torch.matmul(U, torch.diag_embed(S)).to(dtype=target_dtype) + V = V.transpose(1, 2).to(dtype=target_dtype) + return [left_mat, V] + + def _compute_preconditioned_gradient(self, per_sample_gradient: torch.Tensor) -> None: + """Computes the preconditioned per-sample gradient. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample-gradient tensor for the given batch. + """ + preconditioned_gradient = FactorConfig.CONFIGS[self.module.factor_args.strategy].precondition_gradient( + gradient=per_sample_gradient, + storage=self.module.storage, + ) + del per_sample_gradient + + if ( + self.module.score_args.query_gradient_low_rank is not None + and min(preconditioned_gradient.size()[1:]) > self.module.score_args.query_gradient_low_rank + ): + # Apply low-rank approximation to the preconditioned gradient. + preconditioned_gradient = preconditioned_gradient.to( + dtype=self.module.score_args.query_gradient_svd_dtype + ).contiguous() + preconditioned_gradient = self._compute_low_rank_preconditioned_gradient( + preconditioned_gradient=preconditioned_gradient, + target_dtype=self.module.score_args.score_dtype, + ) + else: + preconditioned_gradient = preconditioned_gradient.to(dtype=self.module.score_args.score_dtype) + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = preconditioned_gradient + + def register_hooks(self) -> None: + """Sets up hooks to compute preconditioned per-sample gradient.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ).to(dtype=self.module.score_args.precondition_dtype) + self.clear_all_cache() + self._compute_preconditioned_gradient(per_sample_gradient=per_sample_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + cached_activation = self.cached_activations.pop() + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + if 
self.cached_per_sample_gradient is None: + self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + self.cached_per_sample_gradient.add_(per_sample_gradient) + + self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + + @torch.no_grad() + def finalize_iteration(self) -> None: + """Computes preconditioned gradient using cached per-sample gradients.""" + if self.module.factor_args.has_shared_parameters: + self.cached_per_sample_gradient = self.cached_per_sample_gradient.to( + dtype=self.module.score_args.precondition_dtype + ) + self._compute_preconditioned_gradient(per_sample_gradient=self.cached_per_sample_gradient) + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if preconditioned gradient is available.""" + return self.module.storage[PRECONDITIONED_GRADIENT_NAME] is not None + + def synchronize(self, num_processes: int = 1) -> None: + """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" + if ( + dist.is_initialized() + and torch.cuda.is_available() + and self.module.storage[PRECONDITIONED_GRADIENT_NAME] is not None + ): + if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + for i in range(len(self.module.storage[PRECONDITIONED_GRADIENT_NAME])): + size = self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].size() + stacked_matrix = torch.empty( + size=(num_processes, size[0], size[1], size[2]), + dtype=self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].dtype, + device=self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].device, + ) + torch.distributed.all_gather_into_tensor( + output_tensor=stacked_matrix, + input_tensor=self.module.storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(), + ) + self.module.storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).reshape( + num_processes * size[0], size[1], size[2] + ) + else: + size = self.module.storage[PRECONDITIONED_GRADIENT_NAME].size() + stacked_preconditioned_gradient = torch.empty( + size=(num_processes, size[0], size[1], size[2]), + dtype=self.module.storage[PRECONDITIONED_GRADIENT_NAME].dtype, + device=self.module.storage[PRECONDITIONED_GRADIENT_NAME].device, + ) + torch.distributed.all_gather_into_tensor( + output_tensor=stacked_preconditioned_gradient, + input_tensor=self.module.storage[PRECONDITIONED_GRADIENT_NAME].contiguous(), + ) + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose( + 0, 1 + ).reshape(num_processes * size[0], size[1], size[2]) + + def truncate(self, keep_size: int) -> None: + """Truncates preconditioned gradient to appropriate dimension.""" + if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + assert len(self.module.storage[PRECONDITIONED_GRADIENT_NAME]) == 2 + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = [ + self.module.storage[PRECONDITIONED_GRADIENT_NAME][0][:keep_size], + self.module.storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size], + ] + else: + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = self.module.storage[PRECONDITIONED_GRADIENT_NAME][ + :keep_size + ] + + def accumulate_iterations(self) -> None: + """Accumulates preconditioned gradient across multiple iterations.""" + accumulated_gradient = self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] + gradient = self.module.storage[PRECONDITIONED_GRADIENT_NAME] + + if self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] is None: + if 
isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = [ + tensor.contiguous() for tensor in gradient + ] + else: + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = gradient.contiguous() + + else: + if isinstance(self.module.storage[PRECONDITIONED_GRADIENT_NAME], list): + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = [ + torch.cat((accumulated_gradient[0], gradient[0]), dim=0).contiguous(), + torch.cat((accumulated_gradient[1], gradient[1]), dim=0).contiguous(), + ] + else: + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = torch.cat( + (accumulated_gradient, gradient), dim=0 + ).contiguous() + del self.module.storage[PRECONDITIONED_GRADIENT_NAME] + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + + def finalize_all_iterations(self) -> None: + """Preconditions aggregated gradient if it exists in storage.""" + if self.module.storage[AGGREGATED_GRADIENT_NAME] is not None: + self.module.storage[AGGREGATED_GRADIENT_NAME] = self.module.storage[AGGREGATED_GRADIENT_NAME].to( + dtype=self.module.score_args.precondition_dtype + ) + self._compute_preconditioned_gradient(per_sample_gradient=self.module.storage[AGGREGATED_GRADIENT_NAME]) + del self.module.storage[AGGREGATED_GRADIENT_NAME] + self.module.storage[AGGREGATED_GRADIENT_NAME] = None + self.accumulate_iterations() + + def release_memory(self) -> None: + """Clears preconditioned gradients from memory.""" + del ( + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME], + self.module.storage[PRECONDITIONED_GRADIENT_NAME], + ) + self.module.storage[ACCUMULATED_PRECONDITIONED_GRADIENT_NAME] = None + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + self.clear_all_cache() diff --git a/kronfluence/module/tracker/self_score.py b/kronfluence/module/tracker/self_score.py new file mode 100644 index 0000000..1d84175 --- /dev/null +++ b/kronfluence/module/tracker/self_score.py @@ -0,0 +1,251 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn +from opt_einsum import DynamicProgramming, contract_expression + +from kronfluence.factor.config import FactorConfig +from kronfluence.module.tracker.base import BaseTracker +from kronfluence.utils.constants import ( + ACCUMULATED_PRECONDITIONED_GRADIENT_NAME, + AGGREGATED_GRADIENT_NAME, + PAIRWISE_SCORE_MATRIX_NAME, + PRECONDITIONED_GRADIENT_NAME, + SELF_SCORE_VECTOR_NAME, +) + + +def move_storage_to_device(storage, target_device: torch.device) -> None: + """Moves stored factors into the target device.""" + for name, factor in storage.items(): + if factor is not None: + if isinstance(factor, list): + for i in range(len(storage[name])): + storage[name][i] = factor[i].to(device=target_device) + else: + storage[name] = factor.to(device=target_device) + + +class SelfScoreTracker(BaseTracker): + """Computes self-influence scores for a given module.""" + + storage_at_device: bool = False + + def _compute_self_score(self, per_sample_gradient: torch.Tensor) -> None: + """Computes self-influence scores using per-sample gradients. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample gradient tensor for the given batch. 
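Conceptually, the self-influence score computed below is the per-example inner product between the per-sample gradient and its preconditioned counterpart. A toy sketch using a made-up diagonal preconditioner purely for illustration; the library applies the configured factor strategy (e.g. EKFAC) instead:

import torch

B, o, i = 8, 5, 7
per_sample_grads = torch.randn(B, o, i)

# Stand-in diagonal preconditioner (illustrative only).
damping = 0.1
diagonal = per_sample_grads.pow(2).mean(dim=0, keepdim=True) + damping   # (1, o, i)
preconditioned = per_sample_grads / diagonal

self_scores = (preconditioned * per_sample_grads).sum(dim=(1, 2))        # one score per example
print(self_scores.shape)  # torch.Size([8])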
+ """ + if not self.storage_at_device: + move_storage_to_device( + storage=self.module.storage, + target_device=per_sample_gradient.device, + ) + self.storage_at_device = True + + preconditioned_gradient = ( + FactorConfig.CONFIGS[self.module.factor_args.strategy] + .precondition_gradient( + gradient=per_sample_gradient, + storage=self.module.storage, + ) + .to(dtype=self.module.score_args.score_dtype) + ) + per_sample_gradient = per_sample_gradient.to(dtype=self.module.score_args.score_dtype) + preconditioned_gradient.mul_(per_sample_gradient) + self.module.storage[SELF_SCORE_VECTOR_NAME] = preconditioned_gradient.sum(dim=(1, 2)) + + def register_hooks(self) -> None: + """Sets up hooks to compute self-influence scores.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.per_sample_gradient_dtype, + copy=True, + ) + + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ).to(dtype=self.module.score_args.precondition_dtype) + self.clear_all_cache() + self._compute_self_score(per_sample_gradient=per_sample_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype + ) + cached_activation = self.cached_activations.pop() + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + if self.cached_per_sample_gradient is None: + self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + self.cached_per_sample_gradient.add_(per_sample_gradient) + + self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + + @torch.no_grad() + def finalize_iteration(self) -> None: + """Computes self-influence scores using cached per-sample gradients.""" + if self.module.factor_args.has_shared_parameters: + self.cached_per_sample_gradient = self.cached_per_sample_gradient.to( + dtype=self.module.score_args.precondition_dtype + ) + self._compute_self_score(per_sample_gradient=self.cached_per_sample_gradient) + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if self-influence score is available.""" + return self.module.storage[SELF_SCORE_VECTOR_NAME] is not None + + def accumulate_iterations(self) -> None: + """Removes self-scores from memory after a 
single iteration.""" + self.release_memory() + + def release_memory(self) -> None: + """Releases pairwise scores from memory.""" + self.clear_all_cache() + del self.module.storage[SELF_SCORE_VECTOR_NAME] + self.module.storage[SELF_SCORE_VECTOR_NAME] = None + + +class SelfScoreWithMeasurementTracker(BaseTracker): + """Computes self-influence scores with measurement for a given module.""" + + storage_at_device: bool = False + + def _compute_self_measurement_score_with_gradient(self, per_sample_gradient: torch.Tensor) -> None: + """Computes self-influence scores using per-sample-gradients. + + Args: + per_sample_gradient (torch.Tensor): + The per-sample-gradient tensor for the given batch. + """ + self.module.storage[SELF_SCORE_VECTOR_NAME] = per_sample_gradient.mul_( + self.module.storage[PRECONDITIONED_GRADIENT_NAME] + ).sum(dim=(1, 2)) + del self.module.storage[PRECONDITIONED_GRADIENT_NAME] + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + + def register_hooks(self) -> None: + """Sets up hooks to compute pairwise influence scores.""" + + @torch.no_grad() + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None: + del module + cached_activation = inputs[0].detach() + device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device + cached_activation = cached_activation.to( + device=device, + dtype=self.module.score_args.score_dtype, + copy=True, + ) + + if self.module.factor_args.has_shared_parameters: + if self.cached_activations is None: + self.cached_activations = [] + self.cached_activations.append(cached_activation) + else: + self.cached_activations = cached_activation + + outputs.register_hook( + shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook + ) + + @torch.no_grad() + def backward_hook(output_gradient: torch.Tensor) -> None: + if self.cached_activations is None: + self._raise_cache_not_found_exception() + + if not self.storage_at_device: + move_storage_to_device( + storage=self.module.storage, + target_device=output_gradient.device, + ) + self.storage_at_device = True + + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype + ) + if self.module.per_sample_gradient_process_fnc is None: + self.module.storage[SELF_SCORE_VECTOR_NAME] = self.module.compute_self_measurement_score( + preconditioned_gradient=self.module.storage[PRECONDITIONED_GRADIENT_NAME], + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + del self.module.storage[PRECONDITIONED_GRADIENT_NAME] + self.module.storage[PRECONDITIONED_GRADIENT_NAME] = None + self.clear_all_cache() + else: + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=self.cached_activations.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + self.clear_all_cache() + self._compute_self_measurement_score_with_gradient(per_sample_gradient=per_sample_gradient) + + @torch.no_grad() + def shared_backward_hook(output_gradient: torch.Tensor) -> None: + output_gradient = self._scale_output_gradient( + output_gradient=output_gradient, target_dtype=self.module.score_args.score_dtype + ) + cached_activation = self.cached_activations.pop() + per_sample_gradient = self.module.compute_per_sample_gradient( + input_activation=cached_activation.to(device=output_gradient.device), + output_gradient=output_gradient, + ) + if 
self.cached_per_sample_gradient is None: + self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False) + self.cached_per_sample_gradient.add_(per_sample_gradient) + + self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook)) + + @torch.no_grad() + def finalize_iteration(self) -> None: + """Computes self-influence scores using cached per-sample gradients.""" + if self.module.factor_args.has_shared_parameters: + self._compute_self_measurement_score_with_gradient(per_sample_gradient=self.cached_per_sample_gradient) + self.clear_all_cache() + + def exist(self) -> bool: + """Checks if self-influence score is available.""" + return self.module.storage[SELF_SCORE_VECTOR_NAME] is not None + + def accumulate_iterations(self) -> None: + """Removes self-influence scores from memory after a single iteration.""" + self.release_memory() + + def release_memory(self) -> None: + """Releases self-influence scores from memory.""" + self.clear_all_cache() + del self.module.storage[SELF_SCORE_VECTOR_NAME] + self.module.storage[SELF_SCORE_VECTOR_NAME] = None diff --git a/kronfluence/module/utils.py b/kronfluence/module/utils.py index 7e6e4bc..87bec90 100644 --- a/kronfluence/module/utils.py +++ b/kronfluence/module/utils.py @@ -49,12 +49,12 @@ def wrap_tracked_modules( ) tracked_module_count = 0 - tracked_module_names = task.tracked_modules() if task is not None else None + tracked_module_names = task.get_influence_tracked_modules() if task is not None else None tracked_module_exists_dict = None if tracked_module_names is not None: tracked_module_exists_dict = {name: False for name in tracked_module_names} per_sample_gradient_process_fnc = None - if task is not None and task.do_post_process_per_sample_gradient: + if task is not None and task.enable_post_process_per_sample_gradient: per_sample_gradient_process_fnc = task.post_process_per_sample_gradient named_modules = model.named_modules() @@ -62,7 +62,7 @@ def wrap_tracked_modules( if len(list(module.children())) > 0: continue - # Filters modules based on the task's `tracked_modules` if specified. + # Filters modules based on the task's `get_influence_tracked_modules` if specified. if tracked_module_names is not None and module_name not in tracked_module_names: continue @@ -115,7 +115,7 @@ def set_mode( model: nn.Module, mode: ModuleMode, tracked_module_names: List[str] = None, - keep_factors: bool = False, + release_memory: bool = False, ) -> None: """Sets the module mode of all `TrackedModule` instances within a model. For example, to compute and update covariance matrices, the module mode needs to be set to `ModuleMode.COVARIANCE`. If @@ -129,14 +129,14 @@ tracked_module_names (List[str], optional): The list of names for `TrackedModule` to set the new mode. If not provided, the new mode is set for all available `TrackedModule` within the model. - keep_factors (bool, optional): - If True, existing factors are kept in memory. Defaults to False. + release_memory (bool, optional): + If `True`, releases existing factors from memory when the mode is set. Defaults to `False`.
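Stepping back from the individual trackers added above: they all share one capture mechanism. A forward hook stores the module input and registers a tensor hook on the output, the backward pass then delivers the gradient with respect to that output, and the two are combined into per-sample gradients. A self-contained sketch of that pattern for a plain `nn.Linear` (names like `cached` are illustrative, not kronfluence code):

import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
cached = {}

def forward_hook(module, inputs, output):
    # Cache the input activation; the tensor hook fires during backward with dL/d(output).
    cached["activation"] = inputs[0].detach()
    output.register_hook(lambda grad: cached.update(output_grad=grad.detach()))

handle = layer.register_forward_hook(forward_hook)

x = torch.randn(6, 4)
layer(x).sum().backward()
handle.remove()

# Per-sample weight gradient: outer product of output gradient and input activation.
per_sample_grad = torch.einsum("bo,bi->boi", cached["output_grad"], cached["activation"])
assert torch.allclose(per_sample_grad.sum(dim=0), layer.weight.grad, atol=1e-5)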
""" for module in model.modules(): if isinstance(module, TrackedModule): if tracked_module_names is not None and module.name not in tracked_module_names: continue - module.set_mode(mode=mode, keep_factors=keep_factors) + module.set_mode(mode=mode, release_memory=release_memory) def update_factor_args(model: nn.Module, factor_args: FactorArguments) -> None: @@ -177,15 +177,21 @@ def load_factors( continue factor = module.get_factor(factor_name=factor_name) if factor is not None: - loaded_factors[module.name] = factor.contiguous().clone() if clone else factor + if clone: + loaded_factors[module.name] = factor.clone(memory_format=torch.contiguous_format) + module.release_factor(factor_name=factor_name) + else: + loaded_factors[module.name] = factor return loaded_factors -def set_factors(model: nn.Module, factor_name: str, factors: Dict[str, torch.Tensor]) -> None: +def set_factors(model: nn.Module, factor_name: str, factors: Dict[str, torch.Tensor], clone: bool = False) -> None: """Sets new factor for all `TrackedModule` instances within a model.""" for module in model.modules(): if isinstance(module, TrackedModule): - module.set_factor(factor_name=factor_name, factor=factors[module.name]) + module.set_factor( + factor_name=factor_name, factor=factors[module.name].clone() if clone else factors[module.name] + ) def set_attention_mask( @@ -202,6 +208,8 @@ def set_attention_mask( module.set_attention_mask(attention_mask=None) elif isinstance(attention_mask, torch.Tensor): module.set_attention_mask(attention_mask=attention_mask) + elif attention_mask is None: + module.set_attention_mask(attention_mask=None) else: raise RuntimeError(f"Invalid attention mask `{attention_mask}` provided.") @@ -216,25 +224,50 @@ def set_gradient_scale( module.set_gradient_scale(scale=gradient_scale) -def synchronize_covariance_matrices(model: nn.Module, tracked_module_names: List[str]) -> None: - """Synchronizes covariance matrices for all modules listed in `tracked_module_names`.""" +def prepare_modules(model: nn.Module, tracked_module_names: List[str], device: torch.device) -> None: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.prepare_storage(device=device) + + +def synchronize_modules(model: nn.Module, tracked_module_names: List[str], num_processes: int = 1) -> None: for module in model.modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.synchronize_covariance_matrices() + module.synchronize(num_processes=num_processes) -def finalize_lambda_matrices(model: nn.Module, tracked_module_names: List[str]) -> None: +def truncate(model: nn.Module, tracked_module_names: List[str], keep_size: int) -> None: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.truncate(keep_size=keep_size) + + +def exist_for_all_modules(model: nn.Module, tracked_module_names: List[str]) -> bool: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + if not module.exist(): + return False + return True + + +def accumulate_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + module.accumulate_iterations() + + +def finalize_iteration(model: nn.Module, tracked_module_names: List[str]) -> None: """Updates Lambda matrices for all modules listed in 
`tracked_module_names`.""" for name, module in model.named_modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.finalize_lambda_matrix() + module.finalize_iteration() -def synchronize_lambda_matrices(model: nn.Module, tracked_module_names: List[str]) -> None: - """Synchronizes Lambda matrices for all modules listed in `tracked_module_names`.""" - for module in model.modules(): +def finalize_all_iterations(model: nn.Module, tracked_module_names: List[str]) -> None: + """Updates Lambda matrices for all modules listed in `tracked_module_names`.""" + for name, module in model.named_modules(): if isinstance(module, TrackedModule) and module.name in tracked_module_names: - module.synchronize_lambda_matrices() + module.finalize_all_iterations() def finalize_preconditioned_gradient(model: nn.Module, tracked_module_names: List[str]) -> None: diff --git a/kronfluence/score/__init__.py b/kronfluence/score/__init__.py index 49a4267..41ac691 100644 --- a/kronfluence/score/__init__.py +++ b/kronfluence/score/__init__.py @@ -1,5 +1,5 @@ from .pairwise import ( - _compute_dot_products_with_loader, + compute_dot_products_with_loader, compute_pairwise_scores_with_loaders, load_pairwise_scores, pairwise_scores_exist, diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py new file mode 100644 index 0000000..23c1774 --- /dev/null +++ b/kronfluence/score/dot_product.py @@ -0,0 +1,246 @@ +from typing import Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from accelerate.utils import send_to_device +from torch import autocast, nn +from torch.cuda.amp import GradScaler +from torch.utils import data +from tqdm import tqdm + +from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.module import TrackedModule +from kronfluence.module.tracked_module import ModuleMode +from kronfluence.module.utils import ( + accumulate_iterations, + exist_for_all_modules, + finalize_all_iterations, + finalize_iteration, + set_mode, + synchronize_modules, +) +from kronfluence.task import Task +from kronfluence.utils.constants import ( + ALL_MODULE_NAME, + DISTRIBUTED_SYNC_INTERVAL, + PAIRWISE_SCORE_MATRIX_NAME, + SCORE_TYPE, +) +from kronfluence.utils.logger import TQDM_BAR_FORMAT +from kronfluence.utils.state import State, no_sync, release_memory + +DIMENSION_NOT_MATCH_ERROR_MSG = ( + "The model does not support token-wise score computation. " + "Set `compute_per_module_scores=True` or `compute_per_token_scores=False` " + "to avoid this error." 
+) + + +def compute_dot_products_with_loader( + model: nn.Module, + task: Task, + state: State, + train_loader: data.DataLoader, + factor_args: FactorArguments, + score_args: ScoreArguments, + tracked_module_names: List[str], + scaler: GradScaler, + disable_tqdm: bool = False, +) -> Union[Dict[str, torch.Tensor], torch.Tensor]: + """After computing the preconditioned query gradient, compute dot products with individual training gradients.""" + model.zero_grad(set_to_none=True) + set_mode( + model=model, + mode=ModuleMode.PAIRWISE_SCORE, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() + + dataset_size = len(train_loader.dataset) + score_chunks: Dict[str, List[torch.Tensor]] = {} + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + score_chunks[module.name] = [] + else: + score_chunks[ALL_MODULE_NAME] = [] + + total_steps = 0 + enable_amp = score_args.amp_dtype is not None + + with tqdm( + total=len(train_loader), + desc="Computing pairwise scores (training gradient)", + bar_format=TQDM_BAR_FORMAT, + disable=not state.is_main_process or disable_tqdm, + ) as pbar: + for batch in train_loader: + batch = send_to_device(tensor=batch, device=state.device) + + with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) + with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): + loss = task.compute_train_loss( + batch=batch, + model=model, + sample=False, + ) + scaler.scale(loss).backward() + + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + score_chunks[module.name].append( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() + ) + else: + pairwise_scores = None + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + if pairwise_scores is None: + pairwise_scores = torch.zeros_like( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + ) + try: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + except RuntimeError: + if score_args.compute_per_token_scores: + raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) + raise + score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: + state.wait_for_everyone() + + total_steps += 1 + pbar.update(1) + + model.zero_grad(set_to_none=True) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() + + total_scores: SCORE_TYPE = {} + for module_name, chunks in score_chunks.items(): + total_scores[module_name] = torch.cat(chunks, dim=1) + if state.use_distributed: + total_scores[module_name] = total_scores[module_name].to(device=state.device) + gather_list = None + if state.is_main_process: + gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] + torch.distributed.gather(total_scores[module_name], gather_list) + if 
state.is_main_process: + total_scores[module_name] = torch.cat(gather_list, dim=1)[:, :dataset_size].cpu() + state.wait_for_everyone() + + return total_scores + + +def compute_aggregated_dot_products_with_loader( + model: nn.Module, + task: Task, + state: State, + train_loader: data.DataLoader, + factor_args: FactorArguments, + score_args: ScoreArguments, + tracked_module_names: List[str], + scaler: GradScaler, + disable_tqdm: bool = False, +) -> Union[Dict[str, torch.Tensor], torch.Tensor]: + """After computing the preconditioned query gradient, compute dot products with aggregated training gradients.""" + model.zero_grad(set_to_none=True) + set_mode( + model=model, + mode=ModuleMode.GRADIENT_AGGREGATION, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() + + scores: Dict[str, Optional[torch.Tensor]] = {} + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + scores[module.name] = None + else: + scores[ALL_MODULE_NAME] = None + + enable_amp = score_args.amp_dtype is not None + + if not exist_for_all_modules(model=model, tracked_module_names=tracked_module_names): + with tqdm( + total=len(train_loader), + desc="Computing pairwise scores (training gradient)", + bar_format=TQDM_BAR_FORMAT, + disable=not state.is_main_process or disable_tqdm, + ) as pbar: + for batch in train_loader: + batch = send_to_device(tensor=batch, device=state.device) + + with no_sync(model=model, state=state): + model.zero_grad(set_to_none=True) + with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): + loss = task.compute_train_loss( + batch=batch, + model=model, + sample=False, + ) + scaler.scale(loss).backward() + + if factor_args.has_shared_parameters: + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + + pbar.update(1) + + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) + + set_mode( + model=model, + mode=ModuleMode.PAIRWISE_SCORE, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) + + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + scores[module.name] = module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() + else: + pairwise_scores = None + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + if pairwise_scores is None: + pairwise_scores = torch.zeros_like( + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False + ) + try: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) + except RuntimeError: + if score_args.compute_per_token_scores: + raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) + raise + scores[ALL_MODULE_NAME] = pairwise_scores.cpu() + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + model.zero_grad(set_to_none=True) + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + release_memory() + + return scores diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py index 086d1f7..8dffcbc 100644 --- a/kronfluence/score/pairwise.py +++ b/kronfluence/score/pairwise.py @@ 
-2,7 +2,6 @@ from typing import Dict, List, Optional, Union import torch -import torch.distributed as dist from accelerate.utils import send_to_device from safetensors.torch import load_file, save_file from torch import autocast, nn @@ -11,37 +10,27 @@ from tqdm import tqdm from kronfluence.arguments import FactorArguments, ScoreArguments -from kronfluence.module import TrackedModule from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( - accumulate_preconditioned_gradient, - aggregated_gradient_exist, - compute_pairwise_scores_from_aggregation, - compute_preconditioned_gradient_from_aggregation, - finalize_gradient_aggregation, - finalize_pairwise_scores, - finalize_preconditioned_gradient, + accumulate_iterations, + finalize_all_iterations, + finalize_iteration, get_tracked_module_names, - release_aggregated_gradient, - release_preconditioned_gradient, - release_scores, + prepare_modules, set_factors, set_gradient_scale, set_mode, - synchronize_aggregated_gradient, - synchronize_preconditioned_gradient, - truncate_preconditioned_gradient, + synchronize_modules, + truncate, update_factor_args, update_score_args, ) -from kronfluence.task import Task -from kronfluence.utils.constants import ( - ALL_MODULE_NAME, - FACTOR_TYPE, - PAIRWISE_SCORE_MATRIX_NAME, - PARTITION_TYPE, - SCORE_TYPE, +from kronfluence.score.dot_product import ( + compute_aggregated_dot_products_with_loader, + compute_dot_products_with_loader, ) +from kronfluence.task import Task +from kronfluence.utils.constants import FACTOR_TYPE, PARTITION_TYPE, SCORE_TYPE from kronfluence.utils.logger import TQDM_BAR_FORMAT from kronfluence.utils.state import State, no_sync, release_memory @@ -50,7 +39,18 @@ def pairwise_scores_save_path( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> Path: - """Generates the path for saving/loading pairwise scores.""" + """Generates the path for saving or loading pairwise influence scores. + + Args: + output_dir (Path): + Directory to save the matrices. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + Path: + The full path for the score file. + """ if partition is not None: data_partition, module_partition = partition return output_dir / ( @@ -59,25 +59,24 @@ def pairwise_scores_save_path( return output_dir / "pairwise_scores.safetensors" -def pairwise_scores_exist( - output_dir: Path, - partition: Optional[PARTITION_TYPE] = None, -) -> bool: - """Checks if the pairwise scores exist at the specified path.""" - save_path = pairwise_scores_save_path( - output_dir=output_dir, - partition=partition, - ) - return save_path.exists() - - def save_pairwise_scores( output_dir: Path, scores: SCORE_TYPE, partition: Optional[PARTITION_TYPE] = None, metadata: Optional[Dict[str, str]] = None, ) -> None: - """Saves pairwise influence scores to disk.""" + """Saves pairwise scores to disk. + + Args: + output_dir (Path): + Directory to save the scores. + scores (FACTOR_TYPE): + Dictionary of scores to save. + partition (PARTITION_TYPE, optional): + Partition information, if any. + metadata (Dict[str, str], optional): + Additional metadata to save with the scores. + """ save_path = pairwise_scores_save_path( output_dir=output_dir, partition=partition, @@ -89,7 +88,18 @@ def load_pairwise_scores( output_dir: Path, partition: Optional[PARTITION_TYPE] = None, ) -> Dict[str, torch.Tensor]: - """Loads pairwise scores from disk.""" + """Loads pairwise scores from disk. 
+ + Args: + output_dir (Path): + Directory to load the scores from. + partition (PARTITION_TYPE, optional): + Partition information, if any. + + Returns: + FACTOR_TYPE: + Dictionary of loaded scores. + """ save_path = pairwise_scores_save_path( output_dir=output_dir, partition=partition, @@ -97,216 +107,27 @@ def load_pairwise_scores( return load_file(filename=save_path) -def _compute_dot_products_with_loader( - model: nn.Module, - task: Task, - state: State, - train_loader: data.DataLoader, - factor_args: FactorArguments, - score_args: ScoreArguments, - tracked_module_names: List[str], - scaler: GradScaler, - disable_tqdm: bool = False, -) -> Union[Dict[str, torch.Tensor], torch.Tensor]: - """After computing the preconditioned query gradient, compute dot products with individual training gradients.""" - with torch.no_grad(): - model.zero_grad(set_to_none=True) - set_mode( - model=model, - mode=ModuleMode.PAIRWISE_SCORE, - tracked_module_names=tracked_module_names, - keep_factors=True, - ) - release_memory() - - dataset_size = len(train_loader.dataset) - score_chunks: Dict[str, List[torch.Tensor]] = {} - if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name] = [] - else: - score_chunks[ALL_MODULE_NAME] = [] - - total_steps = 0 - enable_amp = score_args.amp_dtype is not None - - with tqdm( - total=len(train_loader), - desc="Computing pairwise scores (training gradient)", - bar_format=TQDM_BAR_FORMAT, - disable=not state.is_main_process or disable_tqdm, - ) as pbar: - for batch in train_loader: - batch = send_to_device(tensor=batch, device=state.device) - - with no_sync(model=model, state=state): - model.zero_grad(set_to_none=True) - with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): - loss = task.compute_train_loss( - batch=batch, - model=model, - sample=False, - ) - scaler.scale(loss).backward() - - if factor_args.has_shared_parameters: - finalize_pairwise_scores(model=model, tracked_module_names=tracked_module_names) - - with torch.no_grad(): - if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() - ) - else: - pairwise_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if pairwise_scores is None: - pairwise_scores = torch.zeros_like( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False - ) - try: - pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) - except RuntimeError: - if score_args.compute_per_token_scores: - raise RuntimeError( - "The model does not support token-wise score computation. " - "Set `compute_per_module_scores=True` or `compute_per_token_scores=False` " - "to avoid this error." - ) - raise - score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) - release_scores(model=model) - - if state.use_distributed and total_steps % score_args.distributed_sync_interval == 0: - # Periodically synchronizes all processes to avoid timeout at the final synchronization. 
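To make the accumulation pattern in these dot-product loops concrete: every training batch contributes a block of scores with one row per query example and one column per training example in the batch, and the blocks are concatenated along the training dimension once the loop finishes. A toy, single-process sketch with placeholder shapes:

import torch

num_queries, batch_size, num_batches = 3, 4, 5
score_chunks = {"all_modules": []}

for _ in range(num_batches):
    batch_scores = torch.randn(num_queries, batch_size)   # scores for one training batch
    score_chunks["all_modules"].append(batch_scores.cpu())

total_scores = {name: torch.cat(chunks, dim=1) for name, chunks in score_chunks.items()}
print(total_scores["all_modules"].shape)   # torch.Size([3, 20])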
- state.wait_for_everyone() - - total_steps += 1 - pbar.update(1) - - with torch.no_grad(): - model.zero_grad(set_to_none=True) - set_mode( - model=model, - mode=ModuleMode.PRECONDITION_GRADIENT, - tracked_module_names=tracked_module_names, - keep_factors=True, - ) - release_preconditioned_gradient(model=model) - release_memory() - - total_scores: SCORE_TYPE = {} - for module_name, chunks in score_chunks.items(): - total_scores[module_name] = torch.cat(chunks, dim=1) - if state.use_distributed: - total_scores[module_name] = total_scores[module_name].to(device=state.device) - gather_list = None - if state.is_main_process: - gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) - if state.is_main_process: - total_scores[module_name] = torch.cat(gather_list, dim=1)[:, :dataset_size].cpu() - state.wait_for_everyone() - - return total_scores - - -def _compute_aggregated_dot_products_with_loader( - model: nn.Module, - task: Task, - state: State, - train_loader: data.DataLoader, - factor_args: FactorArguments, - score_args: ScoreArguments, - tracked_module_names: List[str], - scaler: GradScaler, - disable_tqdm: bool = False, -) -> Union[Dict[str, torch.Tensor], torch.Tensor]: - """After computing the preconditioned query gradient, compute dot products with aggregated training gradients.""" - with torch.no_grad(): - model.zero_grad(set_to_none=True) - set_mode( - model=model, - mode=ModuleMode.GRADIENT_AGGREGATION, - tracked_module_names=tracked_module_names, - keep_factors=True, - ) - release_memory() - - scores: Dict[str, Optional[torch.Tensor]] = {} - if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - scores[module.name] = None - else: - scores[ALL_MODULE_NAME] = None - - enable_amp = score_args.amp_dtype is not None - - if not aggregated_gradient_exist(model=model, tracked_module_names=tracked_module_names): - release_aggregated_gradient(model=model) - with tqdm( - total=len(train_loader), - desc="Computing pairwise scores (training gradient)", - bar_format=TQDM_BAR_FORMAT, - disable=not state.is_main_process or disable_tqdm, - ) as pbar: - for batch in train_loader: - batch = send_to_device(tensor=batch, device=state.device) - - with no_sync(model=model, state=state): - model.zero_grad(set_to_none=True) - with autocast(device_type=state.device.type, enabled=enable_amp, dtype=score_args.amp_dtype): - loss = task.compute_train_loss( - batch=batch, - model=model, - sample=False, - ) - scaler.scale(loss).backward() - - if factor_args.has_shared_parameters: - finalize_gradient_aggregation(model=model, tracked_module_names=tracked_module_names) - - pbar.update(1) +def pairwise_scores_exist( + output_dir: Path, + partition: Optional[PARTITION_TYPE] = None, +) -> bool: + """Checks if pairwise influence scores exist at the specified directory. 
- with torch.no_grad(): - if state.use_distributed: - synchronize_aggregated_gradient(model=model, tracked_module_names=tracked_module_names) - - compute_pairwise_scores_from_aggregation(model=model, tracked_module_names=tracked_module_names) - - with torch.no_grad(): - if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - scores[module.name] = module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).clone().cpu() - else: - pairwise_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if pairwise_scores is None: - pairwise_scores = torch.zeros_like( - module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME), requires_grad=False - ) - pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) - scores[ALL_MODULE_NAME] = pairwise_scores.cpu() - release_scores(model=model) - - model.zero_grad(set_to_none=True) - set_mode( - model=model, - mode=ModuleMode.PRECONDITION_GRADIENT, - tracked_module_names=tracked_module_names, - keep_factors=True, - ) - release_preconditioned_gradient(model=model) - release_memory() + Args: + output_dir (Path): + Directory to check for scores. + partition (PARTITION_TYPE, optional): + Partition information, if any. - return scores + Returns: + bool: + `True` if scores exist, `False` otherwise. + """ + save_path = pairwise_scores_save_path( + output_dir=output_dir, + partition=partition, + ) + return save_path.exists() def compute_pairwise_scores_with_loaders( @@ -326,7 +147,7 @@ def compute_pairwise_scores_with_loaders( Args: loaded_factors (FACTOR_TYPE): - The factor results to load from, before computing the pairwise scores. + Computed factors. model (nn.Module): The model for which pairwise influence scores will be computed. state (State): @@ -340,38 +161,39 @@ def compute_pairwise_scores_with_loaders( train_loader (data.DataLoader): The data loader that will be used to compute training gradients. score_args (ScoreArguments): - Arguments related to computing pairwise scores. + Arguments for computing pairwise scores. factor_args (FactorArguments): - Arguments related to computing preconditioning factors. + Arguments used to compute factors. tracked_module_names (List[str], optional): A list of module names that pairwise scores will be computed. If not specified, scores will be computed for all available tracked modules. disable_tqdm (bool, optional): - Disables TQDM progress bars. Defaults to False. + Whether to disable the progress bar. Defaults to `False`. Returns: Dict[str, torch.Tensor]: A dictionary containing the module name and its pairwise influence scores. """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - update_score_args(model=model, score_args=score_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - mode=ModuleMode.PRECONDITION_GRADIENT, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - # Loads necessary factors before computing pairwise influence scores. 
- if len(loaded_factors) > 0: - for name in loaded_factors: - set_factors( - model=model, - factor_name=name, - factors=loaded_factors[name], - ) + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors( + model=model, + factor_name=name, + factors=loaded_factors[name], + clone=True, + ) + del loaded_factors + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) total_scores_chunks: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {} total_query_batch_size = per_device_query_batch_size * state.num_processes @@ -387,9 +209,9 @@ def compute_pairwise_scores_with_loaders( set_gradient_scale(model=model, gradient_scale=gradient_scale) dot_product_func = ( - _compute_aggregated_dot_products_with_loader + compute_aggregated_dot_products_with_loader if score_args.aggregate_train_gradients - else _compute_dot_products_with_loader + else compute_dot_products_with_loader ) with tqdm( @@ -412,21 +234,18 @@ def compute_pairwise_scores_with_loaders( scaler.scale(measurement).backward() if factor_args.has_shared_parameters: - finalize_preconditioned_gradient(model=model, tracked_module_names=tracked_module_names) + finalize_iteration(model=model, tracked_module_names=tracked_module_names) if state.use_distributed: - # Stacks preconditioned query gradient across multiple devices or nodes. - synchronize_preconditioned_gradient( + # Stack preconditioned query gradient across multiple devices or nodes. + synchronize_modules( model=model, tracked_module_names=tracked_module_names, num_processes=state.num_processes ) if query_index == len(query_loader) - 1 and query_remainder > 0: - # Removes duplicate data points if the dataset is not exactly divisible - # by the current batch size. - truncate_preconditioned_gradient( - model=model, tracked_module_names=tracked_module_names, keep_size=query_remainder - ) + # Removes duplicate data points if the dataset is not evenly divisible by the current batch size. 
+ truncate(model=model, tracked_module_names=tracked_module_names, keep_size=query_remainder) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) - accumulate_preconditioned_gradient(model=model, tracked_module_names=tracked_module_names) num_accumulations += 1 if ( num_accumulations < score_args.query_gradient_accumulation_steps @@ -448,29 +267,26 @@ def compute_pairwise_scores_with_loaders( disable_tqdm=disable_tqdm, ) - with torch.no_grad(): - if state.is_main_process: - for module_name, current_scores in scores.items(): - if module_name not in total_scores_chunks: - total_scores_chunks[module_name] = [] - total_scores_chunks[module_name].append(current_scores) - state.wait_for_everyone() + if state.is_main_process: + for module_name, current_scores in scores.items(): + if module_name not in total_scores_chunks: + total_scores_chunks[module_name] = [] + total_scores_chunks[module_name].append(current_scores) + state.wait_for_everyone() num_accumulations = 0 pbar.update(1) - with torch.no_grad(): - if state.is_main_process: - for module_name in total_scores_chunks: - total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0) + if state.is_main_process: + for module_name in total_scores_chunks: + total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0) - # Clean up the memory. - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - release_aggregated_gradient(model=model) - set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) - state.wait_for_everyone() + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() return total_scores_chunks @@ -488,27 +304,24 @@ def compute_pairwise_query_aggregated_scores_with_loaders( tracked_module_names: Optional[List[str]], disable_tqdm: bool = False, ) -> Dict[str, torch.Tensor]: - """Computes pairwise influence scores (with query gradients aggregated) for a given model and task.""" + """Computes pairwise influence scores (with query gradients aggregated) for a given model and task. See + `compute_pairwise_scores_with_loaders` for detailed information.""" del per_device_query_batch_size - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - update_score_args(model=model, score_args=score_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - mode=ModuleMode.GRADIENT_AGGREGATION, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - # Loads necessary factors before computing pairwise influence scores. 
- if len(loaded_factors) > 0: - for name in loaded_factors: - set_factors( - model=model, - factor_name=name, - factors=loaded_factors[name], - ) + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + mode=ModuleMode.GRADIENT_AGGREGATION, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True) + del loaded_factors + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) enable_amp = score_args.amp_dtype is not None scaler = GradScaler(enabled=enable_amp) @@ -517,9 +330,9 @@ def compute_pairwise_query_aggregated_scores_with_loaders( set_gradient_scale(model=model, gradient_scale=gradient_scale) dot_product_func = ( - _compute_aggregated_dot_products_with_loader + compute_aggregated_dot_products_with_loader if score_args.aggregate_train_gradients - else _compute_dot_products_with_loader + else compute_dot_products_with_loader ) with tqdm( @@ -541,17 +354,20 @@ def compute_pairwise_query_aggregated_scores_with_loaders( scaler.scale(measurement).backward() if factor_args.has_shared_parameters: - finalize_gradient_aggregation(model=model, tracked_module_names=tracked_module_names) + finalize_iteration(model=model, tracked_module_names=tracked_module_names) pbar.update(1) - with torch.no_grad(): - if state.use_distributed: - synchronize_aggregated_gradient(model=model, tracked_module_names=tracked_module_names) + if state.use_distributed: + synchronize_modules(model=model, tracked_module_names=tracked_module_names) - compute_preconditioned_gradient_from_aggregation(model=model, tracked_module_names=tracked_module_names) - accumulate_preconditioned_gradient(model=model, tracked_module_names=tracked_module_names) - release_aggregated_gradient(model=model) + set_mode( + model=model, + mode=ModuleMode.PRECONDITION_GRADIENT, + tracked_module_names=tracked_module_names, + release_memory=False, + ) + finalize_all_iterations(model=model, tracked_module_names=tracked_module_names) scores = dot_product_func( model=model, @@ -565,13 +381,10 @@ def compute_pairwise_query_aggregated_scores_with_loaders( disable_tqdm=disable_tqdm, ) - with torch.no_grad(): - # Clean up the memory. 
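The `gradient_scale = 1.0 / scaler.get_scale()` pattern used by these score functions compensates for AMP loss scaling, so gradients captured by the hooks stay in the correct units. A minimal sketch of that surrounding pattern in isolation (not kronfluence code; the hook-side multiplication is only indicated in a comment):

import torch
from torch import autocast, nn
from torch.cuda.amp import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 1).to(device)
data = torch.randn(16, 8, device=device)

enable_amp = device == "cuda"
scaler = GradScaler(enabled=enable_amp)
# Gradients captured during backward would be multiplied by this factor
# to undo the loss scaling applied by GradScaler.
gradient_scale = 1.0 / scaler.get_scale() if enable_amp else 1.0

model.zero_grad(set_to_none=True)
with autocast(device_type=device, enabled=enable_amp, dtype=torch.float16 if enable_amp else torch.bfloat16):
    loss = model(data).square().mean()
scaler.scale(loss).backward()
print(gradient_scale)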
- model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - release_aggregated_gradient(model=model) - set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) - state.wait_for_everyone() + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True) + state.wait_for_everyone() return scores diff --git a/kronfluence/score/self.py b/kronfluence/score/self.py index e659d8c..40c6af7 100644 --- a/kronfluence/score/self.py +++ b/kronfluence/score/self.py @@ -13,10 +13,13 @@ from kronfluence.module import TrackedModule from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( + accumulate_iterations, + finalize_iteration, finalize_preconditioned_gradient, finalize_self_measurement_scores, finalize_self_scores, get_tracked_module_names, + prepare_modules, release_scores, set_factors, set_gradient_scale, @@ -27,6 +30,7 @@ from kronfluence.task import Task from kronfluence.utils.constants import ( ALL_MODULE_NAME, + DISTRIBUTED_SYNC_INTERVAL, FACTOR_TYPE, PARTITION_TYPE, SCORE_TYPE, @@ -125,25 +129,21 @@ def compute_self_scores_with_loaders( Dict[str, torch.Tensor]: A dictionary containing the module name and its self-influence scores. """ - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - update_score_args(model=model, score_args=score_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - set_mode( - model=model, - mode=ModuleMode.SELF_SCORE, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - # Loads necessary factors before computing self-influence scores. 
- if len(loaded_factors) > 0: - for name in loaded_factors: - set_factors( - model=model, - factor_name=name, - factors=loaded_factors[name], - ) + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + set_mode( + model=model, + mode=ModuleMode.SELF_SCORE, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors(model=model, factor_name=name, factors=loaded_factors[name], clone=True) + del loaded_factors + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) dataset_size = len(train_loader.dataset) score_chunks: Dict[str, List[torch.Tensor]] = {} @@ -184,58 +184,55 @@ def compute_self_scores_with_loaders( scaler.scale(loss).backward() if factor_args.has_shared_parameters: - finalize_self_scores(model=model, tracked_module_names=tracked_module_names) - - with torch.no_grad(): - if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() + ) + else: + self_scores = None + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False ) - else: - self_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if self_scores is None: - self_scores = torch.zeros_like( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False - ) - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) - score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) - release_scores(model=model) - - if state.use_distributed and total_steps % score_args.distributed_sync_interval == 0: - # Periodically synchronizes all processes to avoid timeout at the final synchronization. 
+ self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: state.wait_for_everyone() total_steps += 1 pbar.update(1) - with torch.no_grad(): - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode( - model=model, - mode=ModuleMode.DEFAULT, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - release_memory() - - total_scores: SCORE_TYPE = {} - for module_name, chunks in score_chunks.items(): - total_scores[module_name] = torch.cat(chunks, dim=0) - if state.use_distributed: - total_scores[module_name] = total_scores[module_name].to(device=state.device) - gather_list = None - if state.is_main_process: - gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) - if state.is_main_process: - total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() - state.wait_for_everyone() + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode( + model=model, + mode=ModuleMode.DEFAULT, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + release_memory() + + total_scores: SCORE_TYPE = {} + for module_name, chunks in score_chunks.items(): + total_scores[module_name] = torch.cat(chunks, dim=0) + if state.use_distributed: + total_scores[module_name] = total_scores[module_name].to(device=state.device) + gather_list = None + if state.is_main_process: + gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] + torch.distributed.gather(total_scores[module_name], gather_list) + if state.is_main_process: + total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() + state.wait_for_everyone() return total_scores @@ -253,19 +250,20 @@ def compute_self_measurement_scores_with_loaders( ) -> Dict[str, torch.Tensor]: """Computes self-influence scores with measurement (instead of the loss) for a given model and task. See `compute_self_scores_with_loaders` for the detailed docstring.""" - with torch.no_grad(): - update_factor_args(model=model, factor_args=factor_args) - update_score_args(model=model, score_args=score_args) - if tracked_module_names is None: - tracked_module_names = get_tracked_module_names(model=model) - # Loads necessary factors before computing self-influence scores. 
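# Illustrative sketch, not part of the patch: when `compute_per_module_scores` is
# disabled, the per-module self-influence vectors are summed into a single vector
# stored under `ALL_MODULE_NAME`, as in the loop above. Module names and values
# here are hypothetical.
import torch

per_module_self_scores = {
    "fc1": torch.tensor([0.1, 0.4, 0.2]),
    "fc2": torch.tensor([0.3, 0.1, 0.5]),
}

self_scores = None
for vector in per_module_self_scores.values():
    if self_scores is None:
        self_scores = torch.zeros_like(vector)
    self_scores.add_(vector)

assert torch.allclose(self_scores, torch.tensor([0.4, 0.5, 0.7]))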
- if len(loaded_factors) > 0: - for name in loaded_factors: - set_factors( - model=model, - factor_name=name, - factors=loaded_factors[name], - ) + update_factor_args(model=model, factor_args=factor_args) + update_score_args(model=model, score_args=score_args) + if tracked_module_names is None: + tracked_module_names = get_tracked_module_names(model=model) + if len(loaded_factors) > 0: + for name in loaded_factors: + set_factors( + model=model, + factor_name=name, + factors=loaded_factors[name], + clone=True, + ) + del loaded_factors + prepare_modules(model=model, tracked_module_names=tracked_module_names, device=state.device) dataset_size = len(train_loader.dataset) score_chunks: Dict[str, List[torch.Tensor]] = {} @@ -299,7 +297,7 @@ def compute_self_measurement_scores_with_loaders( model=model, mode=ModuleMode.PRECONDITION_GRADIENT, tracked_module_names=tracked_module_names, - keep_factors=True, + release_memory=False, ) with no_sync(model=model, state=state): model.zero_grad(set_to_none=True) @@ -308,13 +306,13 @@ def compute_self_measurement_scores_with_loaders( scaler.scale(measurement).backward() if factor_args.has_shared_parameters: - finalize_preconditioned_gradient(model=model, tracked_module_names=tracked_module_names) + finalize_iteration(model=model, tracked_module_names=tracked_module_names) set_mode( model=model, mode=ModuleMode.SELF_MEASUREMENT_SCORE, tracked_module_names=tracked_module_names, - keep_factors=True, + release_memory=False, ) with no_sync(model=model, state=state): model.zero_grad(set_to_none=True) @@ -327,57 +325,54 @@ def compute_self_measurement_scores_with_loaders( scaler.scale(loss).backward() if factor_args.has_shared_parameters: - finalize_self_measurement_scores(model=model, tracked_module_names=tracked_module_names) - - with torch.no_grad(): - if score_args.compute_per_module_scores: - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - score_chunks[module.name].append( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() + finalize_iteration(model=model, tracked_module_names=tracked_module_names) + + if score_args.compute_per_module_scores: + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + score_chunks[module.name].append( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).clone().cpu() + ) + else: + self_scores = None + for module in model.modules(): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + if self_scores is None: + self_scores = torch.zeros_like( + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False ) - else: - self_scores = None - for module in model.modules(): - if isinstance(module, TrackedModule) and module.name in tracked_module_names: - if self_scores is None: - self_scores = torch.zeros_like( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME), requires_grad=False - ) - self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) - score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) - release_scores(model=model) - - if state.use_distributed and total_steps % score_args.distributed_sync_interval == 0: - # Periodically synchronizes all processes to avoid timeout at the final synchronization. 
+ self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) + score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) + accumulate_iterations(model=model, tracked_module_names=tracked_module_names) + + if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: state.wait_for_everyone() total_steps += 1 pbar.update(1) - with torch.no_grad(): - model.zero_grad(set_to_none=True) - if enable_amp: - set_gradient_scale(model=model, gradient_scale=1.0) - set_mode( - model=model, - mode=ModuleMode.DEFAULT, - tracked_module_names=tracked_module_names, - keep_factors=False, - ) - release_memory() - - total_scores: SCORE_TYPE = {} - for module_name, chunks in score_chunks.items(): - total_scores[module_name] = torch.cat(chunks, dim=0) - if state.use_distributed: - total_scores[module_name] = total_scores[module_name].to(device=state.device) - gather_list = None - if state.is_main_process: - gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] - torch.distributed.gather(total_scores[module_name], gather_list) - if state.is_main_process: - total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() - state.wait_for_everyone() + model.zero_grad(set_to_none=True) + if enable_amp: + set_gradient_scale(model=model, gradient_scale=1.0) + set_mode( + model=model, + mode=ModuleMode.DEFAULT, + tracked_module_names=tracked_module_names, + release_memory=True, + ) + release_memory() + + total_scores: SCORE_TYPE = {} + for module_name, chunks in score_chunks.items(): + total_scores[module_name] = torch.cat(chunks, dim=0) + if state.use_distributed: + total_scores[module_name] = total_scores[module_name].to(device=state.device) + gather_list = None + if state.is_main_process: + gather_list = [torch.zeros_like(total_scores[module_name]) for _ in range(state.num_processes)] + torch.distributed.gather(total_scores[module_name], gather_list) + if state.is_main_process: + total_scores[module_name] = torch.cat(gather_list, dim=0)[:dataset_size].cpu() + state.wait_for_everyone() return total_scores diff --git a/kronfluence/task.py b/kronfluence/task.py index 09df889..09a1736 100644 --- a/kronfluence/task.py +++ b/kronfluence/task.py @@ -9,10 +9,14 @@ class Task(ABC): """Abstract base class for task definitions. Extend this class to implement specific tasks (e.g., regression, classification, language modeling) - with custom pipelines (models, data loaders, training objectives). + with custom pipelines (e.g., models, data loaders, training objectives). + + Attributes: + enable_post_process_per_sample_gradient (bool): + Flag to enable post-processing of per-sample gradients. Defaults to `False`. """ - do_post_process_per_sample_gradient: bool = False + enable_post_process_per_sample_gradient: bool = False @abstractmethod def compute_train_loss( @@ -21,21 +25,21 @@ def compute_train_loss( model: nn.Module, sample: bool = False, ) -> torch.Tensor: - """Computes training loss for a given batch and model. + """Computes the training loss for a given batch and model. Args: batch (Any): - Batch of data sourced from the DataLoader. + A batch of data from the DataLoader. model (nn.Module): - The PyTorch model for loss computation. + The PyTorch model used for loss computation. sample (bool): Indicates whether to sample from the model's outputs or to use the actual targets from the - batch. Defaults to False. The case where `sample` is set to True must be implemented to + batch. Defaults to `False`. 
The case where `sample` is set to `True` must be implemented to approximate the true Fisher. Returns: torch.Tensor: - The computed loss as a tensor. + The computed loss as a scalar tensor. """ raise NotImplementedError("Subclasses must implement the `compute_train_loss` method.") @@ -45,67 +49,68 @@ def compute_measurement( batch: Any, model: nn.Module, ) -> torch.Tensor: - """Computes a measurable quantity (e.g., loss, logit, log probability) for a given batch and model. - This is defined as f(θ) from https://arxiv.org/pdf/2308.03296.pdf. + """Computes a measurable quantity for a given batch and model. + + This method calculates f(θ) as defined in https://arxiv.org/pdf/2308.03296.pdf. The measurable quantity + can be a loss, logit, log probability, or any other relevant metric for the task. Args: batch (Any): - Batch of data sourced from the DataLoader. + A batch of data from the DataLoader. model (nn.Module): - The PyTorch model for measurement computation. + The PyTorch model used for measurement computation. Returns: torch.Tensor: - The measurable quantity as a tensor. + The computed measurable quantity as a tensor. """ raise NotImplementedError("Subclasses must implement the `compute_measurement` method.") - def tracked_modules(self) -> Optional[List[str]]: - """Specifies modules for influence score computations. + def get_influence_tracked_modules(self) -> Optional[List[str]]: + """Specifies which modules should be tracked for influence score computations. - Returns None by default, applying computations to all supported modules (e.g., nn.Linear, nn.Conv2d). - Subclasses can override this method to return a list of specific module names if influence functions + Override this method in subclasses to return a list of specific module names if influence functions should only be computed for a subset of the model. Returns: Optional[List[str]]: - A list of module names for which to compute influence functions, or None to indicate that - influence functions should be computed for all applicable modules. + A list of module names to compute influence functions for, or None to compute for + all applicable modules (e.g., nn.Linear, nn.Conv2d). """ def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]: - """Returns masks for data points within a batch that have been padded extra tokens to ensure - consistent length across the batch. Typically, it returns None for models or datasets not requiring - masking. + """Returns attention masks for padded sequences in a batch. - See https://huggingface.co/docs/transformers/en/glossary#attention-mask. + This method is typically used for models or datasets that require masking, such as transformer-based + architectures. For more information, see: https://huggingface.co/docs/transformers/en/glossary#attention-mask. Args: batch (Any): - Batch of data sourced from the DataLoader. + A batch of data from the DataLoader. Returns: Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]: - A binary tensor as the mask for the batch, or None if padding is not used. The mask dimensions should - match `batch_size x num_seq`. For models requiring different masks for different modules - (e.g., encoder-decoder architectures), returns a dictionary mapping module names to their - corresponding masks. + - `None` if padding is not used. + - A binary tensor with dimension `batch_size x num_seq` as the mask for the batch. 
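# Illustrative sketch, not part of the patch: a minimal classification `Task`
# implementation following the interface documented above. The loss and measurement
# choices (summed cross-entropy, sampled labels when `sample=True`) are one
# reasonable option rather than the library's prescribed ones, and the batch format
# is assumed to be an `(inputs, labels)` tuple.
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from torch import nn

from kronfluence.task import Task


class ClassificationTask(Task):
    def compute_train_loss(self, batch: Any, model: nn.Module, sample: bool = False) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if sample:
            # Sample labels from the model's own predictions to approximate the true Fisher.
            with torch.no_grad():
                labels = torch.multinomial(logits.softmax(dim=-1), num_samples=1).flatten()
        return F.cross_entropy(logits, labels, reduction="sum")

    def compute_measurement(self, batch: Any, model: nn.Module) -> torch.Tensor:
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels, reduction="sum")

    def get_influence_tracked_modules(self) -> Optional[List[str]]:
        # Returning None tracks all supported modules (e.g., nn.Linear, nn.Conv2d).
        return None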
+ - A dictionary mapping module names to their corresponding masks for models requiring different + masks for different modules (e.g., encoder-decoder architectures). """ def post_process_per_sample_gradient(self, module_name: str, gradient: torch.Tensor) -> torch.Tensor: - """Post-processes the per-sample-gradient of the module with the given name. The attribute - `do_post_process_per_sample_gradient` needs to be set to `True` to enable this post-processing. + """Post-processes the per-sample gradient of a specific module. + + This method is called only if `do_post_process_per_sample_gradient` is set to `True`. + Override this method in subclasses to implement custom gradient post-processing. Args: module_name (str): - Name of the module. + The name of the module whose gradient is being processed. gradient (torch.Tensor): - The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix - with dimension `batch_size x gradient_dim x activation_dim`. + The per-sample gradient tensor with dimension `batch_size x gradient_dim x activation_dim`. Returns: torch.Tensor: - The modified per-sample-gradient tensor. + The modified per-sample gradient tensor. """ del module_name return gradient diff --git a/kronfluence/utils/constants.py b/kronfluence/utils/constants.py index 120e189..fa2c289 100644 --- a/kronfluence/utils/constants.py +++ b/kronfluence/utils/constants.py @@ -4,6 +4,8 @@ import torch +DISTRIBUTED_SYNC_INTERVAL = 1_000 + FACTOR_TYPE = Dict[str, Dict[str, torch.Tensor]] PARTITION_TYPE = Tuple[int, int] SCORE_TYPE = Dict[str, torch.Tensor] @@ -54,7 +56,8 @@ # Preconditioned per-sample gradient. PRECONDITIONED_GRADIENT_NAME = "preconditioned_gradient" -ACCUMULATED_PRECONDITIONED_GRADIENT_NAME = "aggregated_preconditioned_gradient" +ACCUMULATED_PRECONDITIONED_GRADIENT_NAME = "accumulated_preconditioned_gradient" +AGGREGATED_GRADIENT_NAME = "aggregated_gradient" # Pairwise influence scores. PAIRWISE_SCORE_MATRIX_NAME = "pairwise_score_matrix" # Self-influence scores. diff --git a/kronfluence/utils/dataset.py b/kronfluence/utils/dataset.py index 920cf55..e2d637e 100644 --- a/kronfluence/utils/dataset.py +++ b/kronfluence/utils/dataset.py @@ -16,8 +16,11 @@ @dataclass class DataLoaderKwargs(KwargsHandler): - """The object used to customize `DataLoader`. Please refer to https://pytorch.org/docs/stable/data.html for - detailed information of each argument. The default arguments are copied from PyTorch version 2.3. + """Customization options for `DataLoader`. + + This class encapsulates the arguments used to customize PyTorch's `DataLoader`. Default values are based on + PyTorch version 2.3. For detailed information on each argument, refer to: + https://pytorch.org/docs/stable/data.html. """ num_workers: int = 0 @@ -33,7 +36,21 @@ class DataLoaderKwargs(KwargsHandler): def make_indices_partition(total_data_examples: int, partition_size: int) -> List[Tuple[int, int]]: - """Returns partitioned indices from the total data examples.""" + """Partitions data indices into approximately equal-sized bins. + + Args: + total_data_examples (int): + Total number of data examples. + partition_size (int): + Number of partitions to create. + + Returns: + List[Tuple[int, int]]: + List of tuples, each containing start and end indices for a partition. + + Raises: + ValueError: If `total_data_examples` is less than `partition_size`. 
+ """ if total_data_examples < partition_size: raise ValueError("The total data examples must be equal or greater than the partition size.") # See https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length. @@ -47,8 +64,26 @@ def make_indices_partition(total_data_examples: int, partition_size: int) -> Lis def find_executable_batch_size(func: Callable, start_batch_size: int) -> int: - """Finds executable batch size for calling the function that does not encounter OOM error. The code is motivated - from https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/memory.py#L83.""" + """Finds the largest batch size that can be executed without OOM errors. + + This function progressively reduces the batch size until it finds a size that can be executed + without running out of memory. The code is motivated from: + https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/memory.py#L83 + + Args: + func (Callable): + Function to test with different batch sizes. + start_batch_size (int): + Initial batch size to try. + + Returns: + int: + The largest executable batch size. + + Raises: + RuntimeError: + If no executable batch size is found (reaches zero). + """ batch_size = start_batch_size while True: @@ -67,10 +102,9 @@ def find_executable_batch_size(func: Callable, start_batch_size: int) -> int: class DistributedEvalSampler(Sampler[T_co]): - """DistributedEvalSampler is different from `DistributedSampler`: it does not add extra samples to make - the dataset evenly divisible. DistributedEvalSampler should not be used for training; the distributed processes - could hang forever. See this issue for details: https://github.com/pytorch/pytorch/issues/22584. + """Sampler for distributed setting without adding extra samples. + Unlike `DistributedSampler`, it does not add extra samples to make the dataset evenly divisible across processes. The code is adapted from https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py. """ @@ -112,10 +146,10 @@ def __len__(self) -> int: class DistributedSamplerWithStack(Sampler[T_co]): - """DistributedSampleWithStack is different from `DistributedSampler`. Instead of subsampling, - it stacks the dataset. For example, when `num_replicas` is 3, and the dataset of [0, ..., 9] is given, - the first, second, and third rank should have [0, 1, 2], [3, 4, 5], and [6, 7, 8], respectively. However, - it still adds extra samples to make the dataset evenly divisible (different from DistributedEvalSampler). + """Sampler that stacks the dataset for distributed setting. + + Instead of subsampling, this sampler stacks the dataset across processes. It ensures even distribution by + adding padding samples if necessary. 
""" def __init__( # pylint: disable=super-init-not-called diff --git a/kronfluence/utils/exceptions.py b/kronfluence/utils/exceptions.py index 597ca5b..e64f334 100644 --- a/kronfluence/utils/exceptions.py +++ b/kronfluence/utils/exceptions.py @@ -3,12 +3,12 @@ class FactorsNotFoundError(ValueError): class TrackedModuleNotFoundError(ValueError): - """Exception raised when the tracked module is not found.""" + """Exception raised when a tracked module is not found in the model.""" class IllegalTaskConfigurationError(ValueError): - """Exception raised when the provided task is determined to be invalid.""" + """Exception raised when the provided task configuration is determined to be invalid.""" class UnsupportableModuleError(NotImplementedError): - """Exception raised when the provided module is not supported.""" + """Exception raised when the provided module is not supported by the current implementation.""" diff --git a/kronfluence/utils/logger.py b/kronfluence/utils/logger.py index 04af511..5c90d5e 100644 --- a/kronfluence/utils/logger.py +++ b/kronfluence/utils/logger.py @@ -20,21 +20,33 @@ class MultiProcessAdapter(logging.LoggerAdapter): - """An adapter to assist with logging in multiprocess. + """An adapter for logging in multiprocess environments. - The code is copied from https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py with - minor modifications. + The code is adapted from: https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py. """ def log(self, level: int, msg: str, *args, **kwargs) -> None: - """Delegates logger call after checking if it should log.""" + """Log a message if logging is enabled for this process.""" if self.isEnabledFor(level) and not self.extra["disable_log"]: msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) def get_logger(name: str, disable_log: bool = False, log_level: int = None) -> MultiProcessAdapter: - """Returns the logger with an option to disable logging.""" + """Creates and returns a logger with optional disabling and log level setting. + + Args: + name (str): + Name of the logger. + disable_log (bool): + Whether to disable logging. Defaults to `False`. + log_level (int): + Logging level to set. Defaults to `None`. + + Returns: + MultiProcessAdapter: + Configured logger adapter. + """ logger = logging.getLogger(name) if log_level is not None: logger.setLevel(log_level) @@ -43,16 +55,15 @@ def get_logger(name: str, disable_log: bool = False, log_level: int = None) -> M class Profiler: - """Profiling object to measure the time taken to run a certain operation. The profiler is helpful - for checking any bottlenecks in the code. + """A profiling utility to measure execution time of operations. - The code is modified from: + The code is adapted from: - https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers. - https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/profiler.py. """ def __init__(self, state: State) -> None: - """Initializes an instance of the Profiler class. + """Initializes an instance of the `Profiler` class. 
Args: state (State): @@ -63,7 +74,7 @@ def __init__(self, state: State) -> None: self.recorded_durations = defaultdict(list) def start(self, action_name: str) -> None: - """Defines how to start recording an action.""" + """Start recording an action.""" if not self.state.is_main_process: return if action_name in self.current_actions: @@ -71,7 +82,7 @@ def start(self, action_name: str) -> None: self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: - """Defines how to record the duration once an action is complete.""" + """Stop recording an action and log its duration.""" if not self.state.is_main_process: return end_time = _get_monotonic_time() @@ -83,7 +94,7 @@ def stop(self, action_name: str) -> None: @contextmanager def profile(self, action_name: str) -> Generator: - """Yields a context manager to encapsulate the scope of a profiled action.""" + """Context manager for profiling an action.""" try: self.start(action_name) yield action_name @@ -92,6 +103,7 @@ def profile(self, action_name: str) -> Generator: @torch.no_grad() def _make_report(self) -> Tuple[_TABLE_DATA, float, float]: + """Generate a report of profiled actions.""" total_duration = 0.0 for a, d in self.recorded_durations.items(): d_tensor = torch.tensor(d, dtype=torch.float64, requires_grad=False) @@ -110,7 +122,7 @@ def _make_report(self) -> Tuple[_TABLE_DATA, float, float]: return report, total_calls, total_duration def summary(self) -> str: - """Returns a formatted summary for the Profiler.""" + """Generate a formatted summary of the profiling results.""" sep = os.linesep output_string = "Profiler Report:" @@ -143,37 +155,34 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str class PassThroughProfiler(Profiler): - """A pass through Profiler objective that does not record timing.""" + """A no-op profiler that doesn't record any timing information.""" def start(self, action_name: str) -> None: - """Defines how to start recording an action.""" return def stop(self, action_name: str) -> None: - """Defines how to record the duration once an action is complete.""" return def summary(self) -> str: - """Returns a formatted summary for the Profiler.""" return "" class TorchProfiler(Profiler): - """A PyTorch Profiler objective that provides detailed profiling information: - https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html. + """A profiler that utilizes PyTorch's built-in profiling capabilities. + + This profiler provides detailed information about PyTorch operations, including CPU and CUDA events. + It's useful for low-level profiling in PyTorch. - This is useful for low-level profiling in PyTorch, and is not used by default. + Note: This is not used by default and is intended for detailed performance analysis. """ def __init__(self, state: State) -> None: - """Initializes an instance of the PyTorch Profiler class.""" super().__init__(state=state) self.actions: list = [] self.trace_outputs: list = [] self._set_up_torch_profiler() def start(self, action_name: str) -> None: - """Defines how to start recording an action.""" if action_name in self.current_actions: raise ValueError(f"Attempted to start {action_name} which has already started.") # Set dummy value, since only used to track duplicate actions. 
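# Illustrative sketch, not part of the patch: a stripped-down, standalone timer
# mirroring the `start` / `stop` / `profile` usage pattern of the Profiler above.
# The real class additionally gates on the main process and formats a report; the
# class name and action name below are made up.
import time
from collections import defaultdict
from contextlib import contextmanager


class MiniProfiler:
    def __init__(self) -> None:
        self.current_actions = {}
        self.recorded_durations = defaultdict(list)

    def start(self, action_name: str) -> None:
        if action_name in self.current_actions:
            raise ValueError(f"Attempted to start {action_name} which has already started.")
        self.current_actions[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        start_time = self.current_actions.pop(action_name)
        self.recorded_durations[action_name].append(time.monotonic() - start_time)

    @contextmanager
    def profile(self, action_name: str):
        try:
            self.start(action_name)
            yield action_name
        finally:
            self.stop(action_name)


profiler = MiniProfiler()
with profiler.profile("fit_covariance"):
    time.sleep(0.01)
print(profiler.recorded_durations["fit_covariance"])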
@@ -182,14 +191,12 @@ def start(self, action_name: str) -> None: self._torch_prof.start() def stop(self, action_name: str) -> None: - """Defines how to stop recording an action.""" if action_name not in self.current_actions: raise ValueError(f"Attempting to stop recording an action " f"({action_name}) which was never started.") _ = self.current_actions.pop(action_name) self._torch_prof.stop() def _set_up_torch_profiler(self) -> None: - """Creates the PyTorch profiler object with the necessary arguments.""" self._torch_prof = t_prof.profile( activities=[t_prof.ProfilerActivity.CPU, t_prof.ProfilerActivity.CUDA], record_shapes=True, @@ -200,7 +207,6 @@ def _set_up_torch_profiler(self) -> None: ) def _trace_handler(self, p) -> None: - """Adds the PyTorch Profiler trace output to a list once it is ready.""" # Set metric to sort based on device. is_cpu = self.state.device == torch.device("cpu") sort_by_metric = "self_cpu_time_total" if is_cpu else "self_cuda_time_total" @@ -218,12 +224,10 @@ def _trace_handler(self, p) -> None: self.recorded_durations[self.actions[-1]].append(total_time) def _reset_output(self) -> None: - """Resets actions and outputs list.""" self.actions = [] self.trace_outputs = [] def _high_level_summary(self) -> str: - """Returns a formatted high level summary for the PyTorch Profiler.""" sep = os.linesep output_string = "Overall PyTorch Profiler Report:" @@ -255,7 +259,6 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str return output_string def summary(self) -> str: - """Returns a formatted summary for the PyTorch Profiler.""" assert len(self.actions) == len(self.trace_outputs), ( "Mismatch in the number of actions and outputs collected: " + f"# Actions: {len(self.actions)}, # Ouptuts: {len(self.trace_outputs)}" @@ -272,7 +275,7 @@ def summary(self) -> str: return summary -# Timing utilities copied from +# Timing utilities copied from: # https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/pytorch_utils.py. def _get_monotonic_time() -> float: """Gets the monotonic time after the CUDA synchronization if necessary.""" diff --git a/kronfluence/utils/model.py b/kronfluence/utils/model.py index 1a71a79..e58bd33 100644 --- a/kronfluence/utils/model.py +++ b/kronfluence/utils/model.py @@ -20,21 +20,25 @@ def apply_ddp( rank: int, world_size: int, ) -> DistributedDataParallel: - """Applies DistributedDataParallel (DDP) to the given model. + """Applies DistributedDataParallel (DDP) to the given PyTorch model. Args: model (nn.Module): - The model for which DDP will be applied. + The PyTorch model to be parallelized. local_rank (int): - The local rank of the current process. + The local rank of the current process within its node. rank (int): - The rank of the current process. + The global rank of the current process across all nodes. world_size (int): - The total number of processes. + The total number of processes in the distributed setup. Returns: DistributedDataParallel: - The model wrapped with DDP. + The input model wrapped with DDP. + + Raises: + RuntimeError: + If the distributed initialization fails. """ dist.init_process_group("nccl", rank=rank, world_size=world_size) device = torch.device(f"cuda:{local_rank}") @@ -61,31 +65,35 @@ def apply_fsdp( is_transformer: bool = False, layer_to_wrap: Optional[nn.Module] = None, ) -> FSDP: - """Applies FullyShardedDataParallel (FSDP) to the given model. + """Applies FullyShardedDataParallel (FSDP) to the given PyTorch model. 
Args: model (nn.Module): - The model for which FSDP will be applied. + The PyTorch model to be parallelized. local_rank (int): - The local rank of the current process. + The local rank of the current process within its node. rank (int): - The rank of the current process. + The global rank of the current process across all nodes. world_size (int): - The total number of processes. + The total number of processes in the distributed setup. sharding_strategy (str): - The sharding strategy to use. Defaults to "FULL_SHARD". + The FSDP sharding strategy to use. Defaults to "FULL_SHARD". cpu_offload (bool): - Whether to offload parameters to CPU. Check - https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.CPUOffload. Defaults to True. + Whether to offload parameters to CPU. Defaults to `True`. is_transformer (bool): - Whether the model is a transformer model. Defaults to False. + Whether the model is a transformer. Defaults to `False`. layer_to_wrap (nn.Module, optional): - The specific layer to wrap for transformer models. Required if `is_transformer` is True. - Defaults to None. + The specific layer to wrap for transformer models. Required if `is_transformer` is `True`. Returns: FullyShardedDataParallel: - The model wrapped with FSDP. + The input model wrapped with FSDP. + + Raises: + ValueError: + If an invalid sharding strategy is provided or if `layer_to_wrap` is not provided for transformer models. + RuntimeError: + If the distributed initialization fails. """ dist.init_process_group("nccl", rank=rank, world_size=world_size) device = torch.device(f"cuda:{local_rank}") diff --git a/kronfluence/utils/save.py b/kronfluence/utils/save.py index 90ad112..bb800c8 100644 --- a/kronfluence/utils/save.py +++ b/kronfluence/utils/save.py @@ -5,38 +5,93 @@ import torch from safetensors import safe_open +# Constants for file naming conventions. FACTOR_SAVE_PREFIX = "factors_" SCORE_SAVE_PREFIX = "scores_" - FACTOR_ARGUMENTS_NAME = "factor" SCORE_ARGUMENTS_NAME = "score" def load_file(path: Path) -> Dict[str, torch.Tensor]: - """Loads a dictionary of tensors from the path.""" - load_dict = {} - with safe_open(path, framework="pt", device="cpu") as f: - for key in f.keys(): - load_dict[key] = f.get_tensor(name=key) - return load_dict + """Loads a dictionary of tensors from a file using `safetensors`. + + Args: + path (Path): + The path to the file containing tensor data. + + Returns: + Dict[str, torch.Tensor]: + A dictionary where keys are tensor names and values are the corresponding tensors. + """ + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}.") + try: + with safe_open(path, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys()} + except Exception as e: + raise RuntimeError(f"Error loading file {path}: {str(e)}") from e def save_json(obj: Any, path: Path) -> None: - """Saves the object to a JSON file.""" - with open(path, "w", encoding="utf-8") as f: - json.dump(obj, f, indent=4) + """Saves an object to a JSON file. + + This function serializes the given object to JSON format and writes it to a file. + + Args: + obj (Any): + The object to be saved. Must be JSON-serializable. + path (Path): + The path where the JSON file will be saved. 
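# Illustrative sketch, not part of the patch: round-tripping tensors and JSON
# through the save/load helpers documented above. `safetensors.torch.save_file` is
# used only to produce a file that `load_file` can read; the file names and
# dictionary contents are made up.
import json
import tempfile
from pathlib import Path

import torch
from safetensors.torch import save_file

from kronfluence.utils.save import load_file, save_json

with tempfile.TemporaryDirectory() as tmp_dir:
    tensor_path = Path(tmp_dir) / "factors_example.safetensors"
    save_file({"example_factor": torch.eye(3)}, str(tensor_path))
    loaded = load_file(tensor_path)
    assert torch.allclose(loaded["example_factor"], torch.eye(3))

    json_path = Path(tmp_dir) / "example_args.json"
    save_json({"strategy": "ekfac"}, json_path)
    assert json.loads(json_path.read_text(encoding="utf-8"))["strategy"] == "ekfac"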
+ """ + path.parent.mkdir(parents=True, exist_ok=True) + try: + with open(path, "w", encoding="utf-8") as f: + json.dump(obj, f, indent=4, ensure_ascii=False) + except TypeError as e: + raise TypeError(f"Object is not JSON-serializable: {str(e)}") from e + except Exception as e: + raise IOError(f"Error saving JSON file {path}: {str(e)}") from e def load_json(path: Path) -> Dict[str, Any]: - """Loads an object from the JSON file.""" - with open(path, "rb") as f: - obj = json.load(f) - return obj + """Loads an object from a JSON file. + + Args: + path (Path): + The path to the JSON file to be loaded. + + Returns: + Dict[str, Any]: + The object loaded from the JSON file. + """ + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}.") + with open(path, "r", encoding="utf-8") as f: + return json.load(f) -@torch.no_grad() def verify_models_equivalence(state_dict1: Dict[str, torch.Tensor], state_dict2: Dict[str, torch.Tensor]) -> bool: - """Checks if two models are equivalent given their `state_dict`.""" + """Check if two models are equivalent given their state dictionaries. + + This function compares two model state dictionaries to determine if they represent + equivalent models. It checks for equality in the number of parameters, parameter names, + and parameter values (within a small tolerance). + + Args: + state_dict1 (Dict[str, torch.Tensor]): + The state dictionary of the first model. + state_dict2 (Dict[str, torch.Tensor]): + The state dictionary of the second model. + + Returns: + bool: + `True` if the models are equivalent, `False` otherwise. + + Notes: + - The function uses a relative tolerance of 1.3e-6 and an absolute tolerance of 1e-5 + when comparing tensor values. + - Tensors are compared in float32 precision on the CPU to ensure consistency. + """ if len(state_dict1) != len(state_dict2): return False diff --git a/kronfluence/utils/state.py b/kronfluence/utils/state.py index 08705dd..c145255 100644 --- a/kronfluence/utils/state.py +++ b/kronfluence/utils/state.py @@ -23,12 +23,11 @@ class State: _shared_state: Dict[str, Any] = SharedDict() def __init__(self, cpu: bool = False) -> None: - """Initializes an instance of the State class. + """Initializes an instance of the `State` class. Args: cpu (bool): - Specifies whether the analysis should be explicitly performed using the CPU. - Defaults to False, utilizing GPU resources if available. + If `True`, forces the use of CPU even if GPUs are available. Defaults to `False`. """ self.__dict__ = self._shared_state @@ -51,6 +50,7 @@ def __init__(self, cpu: bool = False) -> None: self.device = torch.device("cpu") if self.cpu else self.default_device def __repr__(self) -> str: + """Provides a string representation of the State instance.""" return ( f"Num processes: {self.num_processes}\n" f"Process index: {self.process_index}\n" @@ -60,50 +60,54 @@ def __repr__(self) -> str: @staticmethod def _reset_state() -> None: - """Resets `_shared_state`, is used internally and should not be called.""" + """Resets the shared state. 
For internal use only.""" State._shared_state.clear() @property def initialized(self) -> bool: - """Returns whether the `PartialState` has been initialized.""" + """Checks if the State has been initialized.""" return self._shared_state != {} @property def use_distributed(self) -> bool: - """Whether the State is configured for distributed training.""" + """Checks if the setup is configured for distributed setting.""" return self.num_processes > 1 @property def is_main_process(self) -> bool: - """Returns whether the current process is the main process.""" + """Checks if the current process is the main process.""" return self.process_index == 0 @property def is_local_main_process(self) -> bool: - """Returns whether the current process is the main process on the local node.""" + """Checks if the current process is the main process on the local node.""" return self.local_process_index == 0 @property def is_last_process(self) -> bool: - """Returns whether the current process is the last one.""" + """Checks if the current process is the last one.""" return self.process_index == self.num_processes - 1 def wait_for_everyone(self) -> None: - """Will stop the execution of the current process until every other process has reached that point - (so this does nothing when the script is only run in one process).""" + """Synchronizes all processes. + + This method will pause the execution of the current process until all other processes + reach this point. It has no effect in single-process execution. + """ if self.use_distributed: dist.barrier() @property def default_device(self) -> torch.device: - """Finds the default device currently available.""" + """Determines the default device (CUDA if available, otherwise CPU).""" if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def release_memory() -> None: - """Releases the memory by calling `gc.collect()` and `torch.cuda.empty_cache()`.""" + """Releases unused memory. This function calls Python's garbage collector and empties CUDA cache + if CUDA is available.""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -111,12 +115,23 @@ def release_memory() -> None: @contextlib.contextmanager def no_sync(model: nn.Module, state: State) -> Callable: - """A context manager to avoid DDP synchronization. The code is adapted from - https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/accelerator.py#L852.""" + """A context manager to temporarily disable gradient synchronization in distributed setting. + + Args: + model (nn.Module): + The PyTorch model. + state (State): + The current process state. + + Yields: + A context where gradient synchronization is disabled (if applicable). + + Note: + For FullyShardedDataParallel (FSDP) models, this may result in higher memory usage. + See: https://pytorch.org/docs/stable/fsdp.html. + """ context = contextlib.nullcontext - # `no_sync()` for FSDP instance can result in higher memory usage, detailed in: - # https://pytorch.org/docs/stable/fsdp.html. 
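# Illustrative sketch, not part of the patch: the fallback selection that `no_sync`
# documents above, where a DDP model's own `no_sync()` context is used in the
# distributed setting and a null context otherwise (the real helper also
# special-cases FSDP). The function name below is hypothetical.
import contextlib

import torch
from torch import nn


def pick_no_sync_context(model: nn.Module, use_distributed: bool):
    context = contextlib.nullcontext
    if use_distributed:
        # `DistributedDataParallel` exposes `no_sync()`; plain modules keep the null context.
        context = getattr(model, "no_sync", context)
    return context


model = nn.Linear(2, 2)
with pick_no_sync_context(model, use_distributed=False)():
    loss = model(torch.randn(3, 2)).sum()
    loss.backward()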
if state.use_distributed and not isinstance(model, FullyShardedDataParallel): context = getattr(model, "no_sync", context) diff --git a/kronfluence/version.py b/kronfluence/version.py index 3dc1f76..5becc17 100644 --- a/kronfluence/version.py +++ b/kronfluence/version.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "1.0.0" diff --git a/tests/factors/test_covariances.py b/tests/factors/test_covariances.py index 9c9ef4c..ca5fa82 100644 --- a/tests/factors/test_covariances.py +++ b/tests/factors/test_covariances.py @@ -401,7 +401,7 @@ def test_covariance_matrices_amp( train_size: int, seed: int, ) -> None: - # Covariance matrices should be similar when AMP is enabled. + # Covariance matrices should be similar even when AMP is enabled. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -511,3 +511,62 @@ def test_covariance_matrices_gradient_checkpoint( atol=ATOL, rtol=RTOL, ) + + +@pytest.mark.parametrize("train_size", [100]) +@pytest.mark.parametrize("seed", [8, 9]) +def test_covariance_matrices_inplace( + train_size: int, + seed: int, +) -> None: + # Covariance matrices should be the identical for with and without in-place ReLU. + model, train_dataset, _, data_collator, task = prepare_test( + test_name="conv", + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) + + factor_args = pytest_factor_arguments() + analyzer.fit_covariance_matrices( + factors_name=DEFAULT_FACTORS_NAME, + dataset=train_dataset, + per_device_batch_size=8, + overwrite_output_dir=True, + factor_args=factor_args, + ) + covariance_factors = analyzer.load_covariance_matrices( + factors_name=DEFAULT_FACTORS_NAME, + ) + + model, _, _, _, task = prepare_test( + test_name="conv_inplace", + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) + analyzer.fit_covariance_matrices( + factors_name=custom_factors_name("inplace"), + dataset=train_dataset, + per_device_batch_size=4, + overwrite_output_dir=True, + factor_args=factor_args, + ) + inplace_covariance_factors = analyzer.load_covariance_matrices( + factors_name=custom_factors_name("inplace"), + ) + + assert check_tensor_dict_equivalence( + covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME], + inplace_covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME], + atol=ATOL, + rtol=RTOL, + ) diff --git a/tests/factors/test_eigens.py b/tests/factors/test_eigendecompositions.py similarity index 98% rename from tests/factors/test_eigens.py rename to tests/factors/test_eigendecompositions.py index ee74da2..3198754 100644 --- a/tests/factors/test_eigens.py +++ b/tests/factors/test_eigendecompositions.py @@ -48,7 +48,7 @@ def test_perform_eigendecomposition( factors_name=DEFAULT_FACTORS_NAME, factor_args=factor_args, dataset=train_dataset, - per_device_batch_size=4, + per_device_batch_size=None, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) diff --git a/tests/factors/test_lambdas.py b/tests/factors/test_lambdas.py index 5ce2d11..751193c 100644 --- a/tests/factors/test_lambdas.py +++ b/tests/factors/test_lambdas.py @@ -29,6 +29,7 @@ "repeated_mlp", "conv", "bert", + "roberta", "gpt", "gpt_checkpoint", ], @@ -500,3 +501,61 @@ def test_lambda_matrices_shared_parameters( for name in LAMBDA_FACTOR_NAMES: assert check_tensor_dict_equivalence(lambda_factors[name], shared_lambda_factors[name], atol=ATOL, rtol=RTOL) + 
+ +@pytest.mark.parametrize("train_size", [121]) +@pytest.mark.parametrize("seed", [8]) +def test_lambda_matrices_inplace( + train_size: int, + seed: int, +) -> None: + # Lambda matrices should be the identical for with and without in-place ReLU. + model, train_dataset, _, data_collator, task = prepare_test( + test_name="conv", + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) + kwargs = DataLoaderKwargs(collate_fn=data_collator) + + factor_args = pytest_factor_arguments() + analyzer.fit_all_factors( + factors_name=DEFAULT_FACTORS_NAME, + dataset=train_dataset, + per_device_batch_size=5, + overwrite_output_dir=True, + factor_args=factor_args, + dataloader_kwargs=kwargs, + ) + lambda_factors = analyzer.load_lambda_matrices( + factors_name=DEFAULT_FACTORS_NAME, + ) + + model, _, _, _, task = prepare_test( + test_name="conv_inplace", + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) + analyzer.fit_all_factors( + factors_name=custom_factors_name("inplace"), + dataset=train_dataset, + per_device_batch_size=6, + overwrite_output_dir=True, + factor_args=factor_args, + dataloader_kwargs=kwargs, + ) + inplace_lambda_factors = analyzer.load_lambda_matrices( + factors_name=custom_factors_name("inplace"), + ) + + for name in LAMBDA_FACTOR_NAMES: + assert check_tensor_dict_equivalence(lambda_factors[name], inplace_lambda_factors[name], atol=ATOL, rtol=RTOL) diff --git a/tests/modules/test_per_sample_gradients.py b/tests/modules/test_per_sample_gradients.py index ed0a459..f2ae5e5 100644 --- a/tests/modules/test_per_sample_gradients.py +++ b/tests/modules/test_per_sample_gradients.py @@ -14,7 +14,7 @@ from kronfluence.arguments import FactorArguments from kronfluence.module.tracked_module import ModuleMode, TrackedModule from kronfluence.module.utils import ( - finalize_preconditioned_gradient, + finalize_iteration, get_tracked_module_names, set_mode, update_factor_args, @@ -157,7 +157,7 @@ def test_for_loop_per_sample_gradient_equivalence( loss.backward() if test_name == "repeated_mlp": - finalize_preconditioned_gradient(model=model, tracked_module_names=tracked_modules_names) + finalize_iteration(model=model, tracked_module_names=tracked_modules_names) module_gradients = {} for module in model.modules(): @@ -251,7 +251,7 @@ def test_mean_gradient_equivalence( loss.backward() if test_name == "repeated_mlp": - finalize_preconditioned_gradient(model=model, tracked_module_names=tracked_modules_names) + finalize_iteration(model=model, tracked_module_names=tracked_modules_names) module_gradients = {} for module in model.modules(): diff --git a/tests/modules/test_scores.py b/tests/modules/test_scores.py deleted file mode 100644 index abac1ac..0000000 --- a/tests/modules/test_scores.py +++ /dev/null @@ -1,135 +0,0 @@ -# pylint: skip-file -import time - -import opt_einsum -import torch -from accelerate.utils import set_seed -from opt_einsum import DynamicProgramming - - -def test_compute_score_matmul( - seed: int = 0, -) -> None: - input_dim = 1024 - output_dim = 2048 - batch_dim = 16 - query_batch_dim = 64 - set_seed(seed) - - gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) - new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) - - score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) - path = 
opt_einsum.contract_path("toi,qoi->tq", gradient, new_gradient) - print(path) - - unsqueeze_score = opt_einsum.contract("t...,q...->tq", gradient, new_gradient) - assert torch.allclose(score, unsqueeze_score) - path = opt_einsum.contract_path("t...,q...->tq", gradient, new_gradient) - print(path) - - -def test_pairwise_score_computation( - seed: int = 0, -) -> None: - input_dim = 4096 - output_dim = 1024 - token_dim = 512 - batch_dim = 32 - query_batch_dim = 16 - rank = 16 - - set_seed(seed) - - lr_gradient1 = torch.rand(size=(query_batch_dim, output_dim, rank), dtype=torch.float64) - lr_gradient2 = torch.rand(size=(query_batch_dim, rank, input_dim), dtype=torch.float64) - lr_gradient = torch.bmm(lr_gradient1, lr_gradient2) - - U, S, V = torch.linalg.svd( - lr_gradient.contiguous(), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - - output_gradient = torch.rand(size=(batch_dim, token_dim, output_dim), dtype=torch.float64) - input_activation = torch.rand(size=(batch_dim, token_dim, input_dim), dtype=torch.float64) - - start_time = time.time() - train_gradient = opt_einsum.contract("b...i,b...o->bio", output_gradient, input_activation) - gt = opt_einsum.contract("qio,bio->qb", lr_gradient, train_gradient) - print(f"Took {time.time() - start_time} seconds.") - - start_time = time.time() - train_gradient = opt_einsum.contract("b...i,b...o->bio", output_gradient, input_activation) - gt_wo_einsum = lr_gradient.reshape(query_batch_dim, -1) @ train_gradient.reshape(batch_dim, -1).T - print(f"Took {time.time() - start_time} seconds.") - - assert torch.allclose(gt, gt_wo_einsum) - - start_time = time.time() - direct1 = opt_einsum.contract("qik,b...i,b...o,qko->qb", left_mat, output_gradient, input_activation, right_mat) - print(f"Took {time.time() - start_time} seconds.") - - start_time = time.time() - direct2 = opt_einsum.contract("qio,b...i,b...o->qb", lr_gradient, output_gradient, input_activation) - print(f"Took {time.time() - start_time} seconds.") - - assert torch.allclose(gt, direct1) - assert torch.allclose(gt, direct2) - - path1 = opt_einsum.contract_path( - "qik,b...i,b...o,qko->qb", left_mat, output_gradient, input_activation, right_mat, optimize="optimal" - ) - path2 = opt_einsum.contract_path( - "qio,b...i,b...o->qb", lr_gradient, output_gradient, input_activation, optimize="optimal" - ) - print(path1) - print(path2) - - print("=" * 80) - - path1 = opt_einsum.contract_path( - "qik,b...i,b...o,qko->qb", left_mat, output_gradient, input_activation, right_mat, optimize="greedy" - ) - path2 = opt_einsum.contract_path( - "qio,b...i,b...o->qb", lr_gradient, output_gradient, input_activation, optimize="greedy" - ) - print(path1) - print(path2) - - print("=" * 80) - - path1 = opt_einsum.contract_path( - "qik,b...i,b...o,qko->qb", left_mat, output_gradient, input_activation, right_mat, optimize="dp" - ) - path2 = opt_einsum.contract_path( - "qio,b...i,b...o->qb", lr_gradient, output_gradient, input_activation, optimize="dp" - ) - print(path1) - print(path2) - - path1 = opt_einsum.contract_path( - "qik,b...i,b...o,qko->qb", - left_mat, - output_gradient, - input_activation, - right_mat, - optimize=DynamicProgramming(search_outer=True, minimize="size"), - ) - path2 = opt_einsum.contract_path( - "qio,b...i,b...o->qb", - lr_gradient, - output_gradient, - input_activation, - optimize=DynamicProgramming(search_outer=True, minimize="size"), - ) - 
print(path1) - print(path2) - - # path1 = opt_einsum.contract_path("qik,b...i,b...o,qko->qb", left_mat, output_gradient, input_activation, right_mat) - # path2 = opt_einsum.contract_path("qio,b...i,b...o->qb", lr_gradient, output_gradient, input_activation) - # print(path1) - # print(path2) diff --git a/tests/modules/test_svd.py b/tests/modules/test_svd.py deleted file mode 100644 index a81fd52..0000000 --- a/tests/modules/test_svd.py +++ /dev/null @@ -1,229 +0,0 @@ -# pylint: skip-file - -import opt_einsum -import pytest -import torch -from accelerate.utils import set_seed -from opt_einsum import DynamicProgramming - - -def test_query_gradient_svd( - seed: int = 0, -) -> None: - input_dim = 2048 - output_dim = 1024 - batch_dim = 16 - set_seed(seed) - - gradient = torch.rand(size=(batch_dim, output_dim, input_dim), dtype=torch.float64) - - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - assert torch.allclose(gradient, U @ torch.diag_embed(S) @ V, atol=1e-5, rtol=1e-3) - - rank = 32 - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - assert torch.bmm(left, right).shape == gradient.shape - - rank = input_dim - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - assert torch.allclose(torch.bmm(left, right), gradient, atol=1e-5, rtol=1e-3) - - rank = 32 - lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank), dtype=torch.float64) - lr_gradient2 = torch.rand(size=(batch_dim, rank, input_dim), dtype=torch.float64) - gradient = torch.bmm(lr_gradient1, lr_gradient2) - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - assert torch.allclose(torch.bmm(left_mat, right_mat), gradient, atol=1e-5, rtol=1e-3) - - query_batch_dim = 32 - new_gradient = torch.rand(size=(query_batch_dim, output_dim, input_dim), dtype=torch.float64) - score = opt_einsum.contract("toi,qoi->tq", gradient, new_gradient) - - lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) - assert torch.allclose(score, lr_score) - - lr_score_reconst_matmul = torch.matmul( - torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() - ) - assert torch.allclose(score, lr_score_reconst_matmul) - - # These should be able to avoid explicit reconstruction. - # This should be used when input_dim > output_dim. - intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) - final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) - assert torch.allclose(score, final) - print("Option 1") - print(intermediate.numel()) - - # This should be used when output_dim > input_dim. 
- intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat) - final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2) - assert torch.allclose(score, final2) - print("Option 2") - print(intermediate2.numel()) - - print("Reconstruction") - print((torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel()) - path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat) - print(path) - - -@pytest.mark.parametrize("input_dim", [256, 512]) -@pytest.mark.parametrize("output_dim", [512, 1024]) -@pytest.mark.parametrize("batch_dim", [8, 16]) -@pytest.mark.parametrize("qbatch_dim", [8, 16]) -@pytest.mark.parametrize("rank", [8, 32]) -@pytest.mark.parametrize("seed", [0]) -def test_query_gradient_svd_reconst( - input_dim: int, - output_dim: int, - batch_dim: int, - qbatch_dim: int, - rank: int, - seed: int, -) -> None: - set_seed(seed) - - lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank + 50), dtype=torch.float64) - lr_gradient2 = torch.rand(size=(batch_dim, rank + 50, input_dim), dtype=torch.float64) - gradient = torch.bmm(lr_gradient1, lr_gradient2) - U, S, V = torch.linalg.svd( - gradient.contiguous(), - full_matrices=False, - ) - U_k = U[:, :, :rank] - S_k = S[:, :rank] - V_k = V[:, :rank, :].clone() - left_mat, right_mat = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous() - - output_gradient = torch.rand(size=(qbatch_dim, output_dim), dtype=torch.float64) - input_activation = torch.rand(size=(qbatch_dim, input_dim), dtype=torch.float64) - new_gradient = opt_einsum.contract("b...i,b...o->bio", output_gradient, input_activation) - - lr_score = opt_einsum.contract("qki,toi,qok->qt", right_mat, new_gradient, left_mat) - lr_score_reconst_matmul = torch.matmul( - torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1), new_gradient.view(new_gradient.shape[0], -1).t() - ) - assert torch.allclose(lr_score, lr_score_reconst_matmul) - - # This should be used when input_dim > output_dim. - intermediate = opt_einsum.contract("qki,toi->qtko", right_mat, new_gradient) - final = opt_einsum.contract("qtko,qok->qt", intermediate, left_mat) - assert torch.allclose(lr_score, final) - print("Option 1") - print(intermediate.numel()) - - # This should be used when output_dim > input_dim. 
-    intermediate2 = torch.einsum("toi,qok->qtik", new_gradient, left_mat)
-    final2 = opt_einsum.contract("qki,qtik->qt", right_mat, intermediate2)
-    assert torch.allclose(lr_score, final2)
-    print("Option 2")
-    print(intermediate2.numel())
-
-    print("Reconstruction")
-    reconst_numel = (torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel()
-    print(reconst_numel)
-
-    # path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat, optimize="greedy")
-    # print(path)
-
-    path = opt_einsum.contract_path(
-        "qki,toi,qok->qt",
-        right_mat,
-        new_gradient,
-        left_mat,
-        optimize=DynamicProgramming(search_outer=True, minimize="size"),
-    )
-    print(path)
-
-    path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat, optimize="optimal")
-    print(path)
-
-    print("Direct")
-    path = opt_einsum.contract_path(
-        "qik,qko,bi,bo->qb",
-        left_mat,
-        right_mat,
-        output_gradient,
-        input_activation,
-        optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-    )
-    direct_result = opt_einsum.contract(
-        "qik,qko,b...i,b...o->qb",
-        left_mat,
-        right_mat,
-        output_gradient,
-        input_activation,
-        optimize=DynamicProgramming(search_outer=True, minimize="flops"),
-    )
-    assert torch.allclose(lr_score, direct_result)
-    print(path)
-
-    # print("Direct")
-    # reconst_numel = (torch.matmul(left_mat, right_mat).view(left_mat.size(0), -1)).numel()
-    # print(reconst_numel)
-    # path = opt_einsum.contract_path("qki,toi,qok->qt", right_mat, new_gradient, left_mat, optimize="optimal")
-    # print(path)
-
-    if new_gradient.size(0) * right_mat.size(0) * rank * min((right_mat.size(2), left_mat.size(1))) > right_mat.size(
-        0
-    ) * right_mat.size(2) * left_mat.size(1):
-        assert intermediate2.numel() > reconst_numel and intermediate.numel() > reconst_numel
-    elif output_dim >= input_dim:
-        assert intermediate2.numel() <= reconst_numel
-    else:
-        assert intermediate.numel() <= reconst_numel
-
-
-def test_query_gradient_svd_vs_low_rank_svd(
-    seed: int = 0,
-) -> None:
-    input_dim = 2048
-    output_dim = 1024
-    batch_dim = 16
-    set_seed(seed)
-
-    rank = 32
-    lr_gradient1 = torch.rand(size=(batch_dim, output_dim, rank), dtype=torch.float32)
-    lr_gradient2 = torch.rand(size=(batch_dim, rank, input_dim), dtype=torch.float32)
-    gradient = torch.bmm(lr_gradient1, lr_gradient2)
-
-    U, S, V = torch.linalg.svd(
-        gradient.contiguous(),
-        full_matrices=False,
-    )
-
-    U_k = U[:, :, :rank]
-    S_k = S[:, :rank]
-    V_k = V[:, :rank, :].clone()
-    left, right = torch.matmul(U_k, torch.diag_embed(S_k)).contiguous(), V_k.contiguous()
-    assert torch.bmm(left, right).shape == gradient.shape
-    print(f"Error: {(torch.bmm(left, right) - gradient).norm()}")
-
-    new_U, new_S, new_V = torch.svd_lowrank(
-        gradient.contiguous(),
-        q=rank,
-    )
-    new_left, new_right = torch.matmul(new_U, torch.diag_embed(new_S)).contiguous(), new_V.transpose(1, 2).contiguous()
-    assert torch.bmm(new_left, new_right).shape == gradient.shape
-    print(f"Error: {(torch.bmm(new_left, new_right) - gradient).norm()}")
diff --git a/tests/scores/test_pairwise_scores.py b/tests/scores/test_pairwise_scores.py
index f442eef..a0a85db 100644
--- a/tests/scores/test_pairwise_scores.py
+++ b/tests/scores/test_pairwise_scores.py
@@ -656,18 +656,22 @@ def test_query_accumulation_steps(
     "test_name",
     [
         "mlp",
-        "repeated_mlp",
-        "roberta",
+        # "repeated_mlp",
+        # "roberta",
     ],
 )
 @pytest.mark.parametrize("query_size", [50])
 @pytest.mark.parametrize("train_size", [32])
+@pytest.mark.parametrize("data_partitions", [3])
+@pytest.mark.parametrize("module_partitions", [3])
 @pytest.mark.parametrize("query_gradient_low_rank", [None])
 @pytest.mark.parametrize("seed", [8])
 def test_query_gradient_aggregation(
     test_name: str,
     query_size: int,
     train_size: int,
+    data_partitions: int,
+    module_partitions: int,
     query_gradient_low_rank: Optional[int],
     seed: int,
 ) -> None:
@@ -712,6 +716,8 @@ def test_query_gradient_aggregation(
     scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)

     score_args.aggregate_query_gradients = True
+    score_args.data_partitions = data_partitions
+    score_args.module_partitions = module_partitions
     analyzer.compute_pairwise_scores(
         scores_name=custom_scores_name("aggregation"),
         factors_name=DEFAULT_FACTORS_NAME,
@@ -723,13 +729,14 @@ def test_query_gradient_aggregation(
         score_args=score_args,
         overwrite_output_dir=True,
     )
-    partitioned_scores = analyzer.load_pairwise_scores(
+    aggregated_scores = analyzer.load_pairwise_scores(
         scores_name=custom_scores_name("aggregation"),
     )

+    assert aggregated_scores[ALL_MODULE_NAME].shape[0] == 1
     assert torch.allclose(
         scores[ALL_MODULE_NAME].sum(dim=0, keepdim=True),
-        partitioned_scores[ALL_MODULE_NAME],
+        aggregated_scores[ALL_MODULE_NAME],
         atol=ATOL,
         rtol=RTOL,
     )
@@ -739,12 +746,14 @@
     "test_name",
     [
         "mlp",
-        "conv_bn",
-        "gpt",
+        # "conv_bn",
+        # "gpt",
     ],
 )
 @pytest.mark.parametrize("query_size", [64])
 @pytest.mark.parametrize("train_size", [32])
+@pytest.mark.parametrize("data_partitions", [3])
+@pytest.mark.parametrize("module_partitions", [2])
 @pytest.mark.parametrize("aggregate_query_gradients", [True, False])
 @pytest.mark.parametrize("query_gradient_low_rank", [None])
 @pytest.mark.parametrize("seed", [9])
@@ -752,6 +761,8 @@ def test_train_gradient_aggregation(
     test_name: str,
     query_size: int,
     train_size: int,
+    data_partitions: int,
+    module_partitions: int,
     aggregate_query_gradients: bool,
     query_gradient_low_rank: Optional[int],
     seed: int,
@@ -794,6 +805,8 @@ def test_train_gradient_aggregation(
     scores = analyzer.load_pairwise_scores(scores_name=DEFAULT_SCORES_NAME)

     score_args.aggregate_train_gradients = True
+    score_args.data_partitions = data_partitions
+    score_args.module_partitions = module_partitions
     analyzer.compute_pairwise_scores(
         scores_name=custom_scores_name("aggregation"),
         factors_name=DEFAULT_FACTORS_NAME,
@@ -809,6 +822,7 @@
         scores_name=custom_scores_name("aggregation"),
     )

+    assert aggregated_scores[ALL_MODULE_NAME].shape[1] == 1
     assert torch.allclose(
         scores[ALL_MODULE_NAME].sum(dim=1, keepdim=True),
         aggregated_scores[ALL_MODULE_NAME],
diff --git a/tests/scores/test_self_scores.py b/tests/scores/test_self_scores.py
index c0ee190..619f147 100644
--- a/tests/scores/test_self_scores.py
+++ b/tests/scores/test_self_scores.py
@@ -163,7 +163,8 @@ def test_compute_self_scores_dtype(
         "conv_bn",
     ],
 )
-@pytest.mark.parametrize("strategy", ["identity", "diagonal", "kfac", "ekfac"])
+# @pytest.mark.parametrize("strategy", ["identity", "diagonal", "kfac", "ekfac"])
+@pytest.mark.parametrize("strategy", ["ekfac"])
 @pytest.mark.parametrize("train_size", [49])
 @pytest.mark.parametrize("seed", [2])
 def test_self_scores_batch_size_equivalence(
diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py
index a51b852..e5293ea 100644
--- a/tests/test_analyzer.py
+++ b/tests/test_analyzer.py
@@ -57,7 +57,7 @@ def test_analyzer(
     factor_args = FactorArguments(strategy=strategy)
     if test_name == "repeated_mlp":
-        factor_args.shared_parameters_exist = True
+        factor_args.has_shared_parameters = True
     analyzer.fit_all_factors(
         factors_name=f"pytest_{test_analyzer.__name__}_{test_name}",
         dataset=train_dataset,
@@ -86,6 +86,7 @@ def test_analyzer(
         score_args=score_args,
         overwrite_output_dir=True,
     )
+    score_args.use_measurement_for_self_influence = True
     analyzer.compute_self_scores(
         scores_name="self",
         factors_name=f"pytest_{test_analyzer.__name__}_{test_name}",
@@ -102,7 +103,6 @@ def test_default_factor_arguments() -> None:

     assert factor_args.strategy == "ekfac"
     assert factor_args.use_empirical_fisher is False
-    assert factor_args.distributed_sync_interval == 1000
     assert factor_args.amp_dtype is None
     assert factor_args.has_shared_parameters is False

@@ -127,7 +127,6 @@ def test_default_score_arguments() -> None:
     score_args = ScoreArguments()

     assert score_args.damping_factor == 1e-08
-    assert score_args.distributed_sync_interval == 1000
     assert score_args.amp_dtype is None
     assert score_args.offload_activations_to_cpu is False
     assert score_args.einsum_minimize_size is False
diff --git a/tests/testable_tasks/classification.py b/tests/testable_tasks/classification.py
index db89a42..1282f2a 100644
--- a/tests/testable_tasks/classification.py
+++ b/tests/testable_tasks/classification.py
@@ -26,6 +26,18 @@ def make_conv_model(bias: bool = True, seed: int = 0) -> nn.Module:
     )


+def make_conv_inplace_model(bias: bool = True, seed: int = 0) -> nn.Module:
+    set_seed(seed)
+    return nn.Sequential(
+        nn.Conv2d(3, 4, 3, 1, bias=bias),
+        nn.ReLU(inplace=True),
+        nn.Conv2d(4, 8, 3, 1, bias=bias),
+        nn.ReLU(inplace=True),
+        nn.Flatten(),
+        nn.Linear(1152, 5, bias=bias),
+    )
+
+
 def make_conv_bn_model(bias: bool = True, seed: int = 0) -> nn.Module:
     set_seed(seed)
     return nn.Sequential(
diff --git a/tests/testable_tasks/language_modeling.py b/tests/testable_tasks/language_modeling.py
index e3fb016..4c7f1e3 100644
--- a/tests/testable_tasks/language_modeling.py
+++ b/tests/testable_tasks/language_modeling.py
@@ -137,7 +137,7 @@ def compute_measurement(
     ) -> torch.Tensor:
         return self.compute_train_loss(batch, model)

-    def tracked_modules(self) -> List[str]:
+    def get_influence_tracked_modules(self) -> List[str]:
         total_modules = []

         for i in range(5):
diff --git a/tests/testable_tasks/multiple_choice.py b/tests/testable_tasks/multiple_choice.py
index 52f40de..971b6fb 100644
--- a/tests/testable_tasks/multiple_choice.py
+++ b/tests/testable_tasks/multiple_choice.py
@@ -80,7 +80,7 @@ def preprocess_function(examples: Any):


 class MultipleChoiceTask(Task):
-    do_post_process_per_sample_gradient = True
+    enable_post_process_per_sample_gradient = True

     def compute_train_loss(
         self,
diff --git a/tests/utils.py b/tests/utils.py
index 696a75c..28caf35 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,6 +15,7 @@
     WrongClassificationTask,
     make_classification_dataset,
     make_conv_bn_model,
+    make_conv_inplace_model,
     make_conv_model,
 )
 from tests.testable_tasks.language_modeling import (
@@ -100,6 +101,12 @@ def prepare_test(
         query_dataset = make_classification_dataset(num_data=query_size, seed=seed + 1)
         task = ClassificationTask()
         data_collator = None
+    elif test_name == "conv_inplace":
+        model = make_conv_inplace_model(seed=seed)
+        train_dataset = make_classification_dataset(num_data=train_size, seed=seed)
+        query_dataset = make_classification_dataset(num_data=query_size, seed=seed + 1)
+        task = ClassificationTask()
+        data_collator = None
     elif test_name == "wrong_conv":
         model = make_conv_model(seed=seed)
         train_dataset = make_classification_dataset(num_data=train_size, seed=seed)