diff --git a/kronfluence/computer/computer.py b/kronfluence/computer/computer.py
index a1723b5..d1e8b20 100644
--- a/kronfluence/computer/computer.py
+++ b/kronfluence/computer/computer.py
@@ -39,7 +39,7 @@
     find_executable_batch_size,
     make_indices_partition,
 )
-from kronfluence.utils.exceptions import FactorsNotFoundError, UnsupportableModuleError
+from kronfluence.utils.exceptions import FactorsNotFoundError, UnsupportableModuleError, TrackedModuleNotFoundError
 from kronfluence.utils.logger import PassThroughProfiler, Profiler, get_logger, get_time
 from kronfluence.utils.save import (
     FACTOR_ARGUMENTS_NAME,
@@ -80,15 +80,16 @@ def __init__(
         self.model.eval()
         self.task = task
 
-        tracked_module_names = get_tracked_module_names(self.model)
-        if len(tracked_module_names) == 0:
+        try:
+            tracked_module_names = get_tracked_module_names(self.model)
+        except TrackedModuleNotFoundError as e:
             error_msg = (
                 f"No tracked modules found in the provided model: {self.model}. "
                 f"Please make sure to run `prepare_model` before passing it in to the "
                 f"Analyzer."
             )
             self.logger.error(error_msg)
-            raise UnsupportableModuleError(error_msg)
+            raise UnsupportableModuleError(error_msg) from e
         self.logger.info(f"Tracking modules with names: {tracked_module_names}.")
 
         if self.state.use_distributed and not isinstance(model, (DDP, FSDP)):
diff --git a/kronfluence/computer/covariance_computer.py b/kronfluence/computer/covariance_computer.py
index dd3dacc..8eeda6b 100644
--- a/kronfluence/computer/covariance_computer.py
+++ b/kronfluence/computer/covariance_computer.py
@@ -170,7 +170,7 @@ def fit_covariance_matrices(
 
         if no_partition:
             if total_data_examples < self.state.num_processes:
-                error_msg = "There are more data examples than the number of processes."
+                error_msg = "The number of processes is greater than the number of data examples."
                 self.logger.error(error_msg)
                 raise ValueError(error_msg)
             if per_device_batch_size is None:
@@ -237,7 +237,7 @@ def fit_covariance_matrices(
             max_total_examples = total_data_examples // factor_args.covariance_data_partition_size
             if max_total_examples < self.state.num_processes:
-                error_msg = "There are more data examples than the number of processes."
+                error_msg = "The number of processes is greater than the number of data examples."
                 self.logger.error(error_msg)
                 raise ValueError(error_msg)
             if per_device_batch_size is None:
@@ -298,10 +298,10 @@ def aggregate_covariance_matrices(
         data_partition_size = factor_args.covariance_data_partition_size
         module_partition_size = factor_args.covariance_module_partition_size
         all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)]
-        all_partition_exists = [
+        all_partition_exists = all(
             covariance_matrices_exist(output_dir=factors_output_dir, partition=partition)
             for partition in all_required_partitions
-        ]
+        )
         if not all_partition_exists:
             self.logger.info(
                 "Covariance matrices are not aggregated as covariance matrices for some partitions "
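Reviewer note (illustrative, not part of the patch): the `aggregate_covariance_matrices` change above fixes a truthiness bug. A non-empty list of booleans is always truthy, so the old `if not all_partition_exists:` guard could never fire once any partitions existed; `all(...)` evaluates the booleans themselves. A minimal standalone sketch with made-up values:

```python
partition_exists = [True, False, True]  # pretend one partition file is missing

as_list = [flag for flag in partition_exists]
print(not as_list)  # False -- the missing partition goes unnoticed

as_all = all(flag for flag in partition_exists)
print(not as_all)   # True -- aggregation is correctly skipped
```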
diff --git a/kronfluence/computer/eigen_computer.py b/kronfluence/computer/eigen_computer.py
index a9b7baf..840cdaf 100644
--- a/kronfluence/computer/eigen_computer.py
+++ b/kronfluence/computer/eigen_computer.py
@@ -288,7 +288,7 @@ def fit_lambda_matrices(
 
         if no_partition:
             if total_data_examples < self.state.num_processes:
-                error_msg = "There are more data examples than the number of processes."
+                error_msg = "The number of processes is greater than the number of data examples."
                 self.logger.error(error_msg)
                 raise ValueError(error_msg)
             if per_device_batch_size is None:
@@ -354,7 +354,7 @@ def fit_lambda_matrices(
             max_total_examples = total_data_examples // factor_args.lambda_data_partition_size
             if max_total_examples < self.state.num_processes:
-                error_msg = "There are more data examples than the number of processes."
+                error_msg = "The number of processes is greater than the number of data examples."
                 self.logger.error(error_msg)
                 raise ValueError(error_msg)
             if per_device_batch_size is None:
diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py
index e9684a5..ab13b07 100644
--- a/kronfluence/factor/eigen.py
+++ b/kronfluence/factor/eigen.py
@@ -293,7 +293,7 @@ def fit_lambda_matrices_with_loader(
 
     with torch.no_grad():
         saved_factors: FACTOR_TYPE = {}
-        for covariance_factor_name in LAMBDA_FACTOR_NAMES:
-            saved_factors[covariance_factor_name] = load_factors(model=model, factor_name=covariance_factor_name)
+        for lambda_factor_name in LAMBDA_FACTOR_NAMES:
+            saved_factors[lambda_factor_name] = load_factors(model=model, factor_name=lambda_factor_name)
         set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False)
     return num_data_processed, saved_factors
diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py
index 6158e36..285c899 100644
--- a/kronfluence/module/linear.py
+++ b/kronfluence/module/linear.py
@@ -71,7 +71,7 @@ def _compute_per_sample_gradient(
 
         Returns:
             torch.Tensor: The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix
-                with dimension `batch_size x input_dim x gradient_dim`. An additional dimension is added
+                with dimension `batch_size x gradient_dim x input_dim`. An additional dimension is added
                 when the bias term is used.
         """
         if self.original_module.bias is not None:
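Reviewer note (illustrative, not part of the patch): the corrected docstring in `linear.py` matches the standard per-sample-gradient identity for a linear layer: the gradient with respect to the weight is the outer product of the output gradient and the input, so `gradient_dim` (the output dimension) precedes `input_dim`. A quick sketch with made-up shapes, not the library's actual implementation:

```python
import torch

batch_size, input_dim, gradient_dim = 4, 3, 5
inputs = torch.randn(batch_size, input_dim)               # activations entering the layer
output_gradients = torch.randn(batch_size, gradient_dim)  # gradients w.r.t. the layer output

# Per-sample weight gradient: one outer product per example.
per_sample_grads = torch.einsum("bo,bi->boi", output_gradients, inputs)
print(per_sample_grads.shape)  # torch.Size([4, 5, 3]) == batch_size x gradient_dim x input_dim
```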
diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py
index 31f1680..30523ef 100644
--- a/kronfluence/module/tracked_module.py
+++ b/kronfluence/module/tracked_module.py
@@ -56,10 +56,9 @@ class TrackedModule(nn.Module):
 
     SUPPORTED_MODULES: Dict[Type[nn.Module], Any] = {}
 
     def __init_subclass__(cls, module_type: Optional[Type[nn.Module]] = None, **kwargs) -> None:
         """Automatically registers subclasses as supported modules."""
         super().__init_subclass__(**kwargs)
-        assert module_type is not None
         if module_type is not None:
             cls.SUPPORTED_MODULES[module_type] = cls
 
@@ -317,7 +316,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _release_covariance_matrices(self) -> None:
         """Clears the stored activation and pseudo-gradient covariance matrices from memory."""
@@ -492,7 +493,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _release_lambda_matrix(self) -> None:
         """Clears the stored Lambda matrix from memory."""
@@ -608,7 +611,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _release_preconditioned_gradient(self) -> None:
         """Clears the preconditioned per-sample-gradient from memory."""
@@ -728,7 +733,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def _register_self_score_hooks(self) -> None:
         """Installs forward and backward hooks for computation of self-influence scores."""
@@ -786,7 +793,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
         if self.factor_args.immediate_gradient_removal:
-            self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+            self._registered_hooks.append(
+                self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+            )
 
     def release_scores(self) -> None:
         """Clears the influence scores from memory."""
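Reviewer note (illustrative, not part of the patch): the repeated hook-registration change swaps `self` (the `TrackedModule` wrapper) for `self.original_module`, so the full backward hook is attached to the wrapped layer itself. A minimal sketch of the PyTorch API involved; the module and hook body here are made up, and only the call pattern mirrors the diff:

```python
import torch
import torch.nn as nn

def gradient_removal_hook(module, grad_input, grad_output):
    # Illustrative stand-in for `full_backward_gradient_removal_hook`; full
    # backward hooks receive gradients w.r.t. the module's inputs and outputs.
    print(f"backward hook fired on {module.__class__.__name__}")

layer = nn.Linear(3, 5)
handle = layer.register_full_backward_hook(gradient_removal_hook)

x = torch.randn(2, 3, requires_grad=True)
layer(x).sum().backward()  # prints: backward hook fired on Linear
handle.remove()  # handles are kept (cf. `_registered_hooks`) so hooks can be detached later
```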