Skip to content

Commit

Permalink
Merge pull request #2 from pomonam/code-review
Browse files (browse the repository at this point in the history)
Minor documentation and bug fixes
  • Loading branch information
pomonam authored Mar 16, 2024
2 parents bc56093 + 3737318 commit 0f2075a
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 20 deletions.
9 changes: 5 additions & 4 deletions kronfluence/computer/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
find_executable_batch_size,
make_indices_partition,
)
- from kronfluence.utils.exceptions import FactorsNotFoundError, UnsupportableModuleError
+ from kronfluence.utils.exceptions import FactorsNotFoundError, UnsupportableModuleError, TrackedModuleNotFoundError
from kronfluence.utils.logger import PassThroughProfiler, Profiler, get_logger, get_time
from kronfluence.utils.save import (
FACTOR_ARGUMENTS_NAME,
Expand Down Expand Up @@ -80,15 +80,16 @@ def __init__(
self.model.eval()
self.task = task

- tracked_module_names = get_tracked_module_names(self.model)
- if len(tracked_module_names) == 0:
+ try:
+ tracked_module_names = get_tracked_module_names(self.model)
+ except TrackedModuleNotFoundError as e:
error_msg = (
f"No tracked modules found in the provided model: {self.model}. "
f"Please make sure to run `prepare_model` before passing it in to the "
f"Analyzer."
)
self.logger.error(error_msg)
- raise UnsupportableModuleError(error_msg)
+ raise UnsupportableModuleError(error_msg) from e
self.logger.info(f"Tracking modules with names: {tracked_module_names}.")

if self.state.use_distributed and not isinstance(model, (DDP, FSDP)):
Expand Down
8 changes: 4 additions & 4 deletions kronfluence/computer/covariance_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def fit_covariance_matrices(

if no_partition:
if total_data_examples < self.state.num_processes:
- error_msg = "There are more data examples than the number of processes."
+ error_msg = "The number of processes are more than the data examples."
self.logger.error(error_msg)
raise ValueError(error_msg)
if per_device_batch_size is None:
Expand Down Expand Up @@ -237,7 +237,7 @@ def fit_covariance_matrices(

max_total_examples = total_data_examples // factor_args.covariance_data_partition_size
if max_total_examples < self.state.num_processes:
- error_msg = "There are more data examples than the number of processes."
+ error_msg = "The number of processes are more than the data examples."
self.logger.error(error_msg)
raise ValueError(error_msg)
if per_device_batch_size is None:
Expand Down Expand Up @@ -298,10 +298,10 @@ def aggregate_covariance_matrices(
data_partition_size = factor_args.covariance_data_partition_size
module_partition_size = factor_args.covariance_module_partition_size
all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)]
- all_partition_exists = [
+ all_partition_exists = all(
covariance_matrices_exist(output_dir=factors_output_dir, partition=partition)
for partition in all_required_partitions
- ]
+ )
if not all_partition_exists:
self.logger.info(
"Covariance matrices are not aggregated as covariance matrices for some partitions "
Expand Down
4 changes: 2 additions & 2 deletions kronfluence/computer/eigen_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def fit_lambda_matrices(

if no_partition:
if total_data_examples < self.state.num_processes:
- error_msg = "There are more data examples than the number of processes."
+ error_msg = "The number of processes are more than the data examples."
self.logger.error(error_msg)
raise ValueError(error_msg)
if per_device_batch_size is None:
Expand Down Expand Up @@ -354,7 +354,7 @@ def fit_lambda_matrices(

max_total_examples = total_data_examples // factor_args.lambda_data_partition_size
if max_total_examples < self.state.num_processes:
- error_msg = "There are more data examples than the number of processes."
+ error_msg = "The number of processes are more than the data examples."
self.logger.error(error_msg)
raise ValueError(error_msg)
if per_device_batch_size is None:
Expand Down
4 changes: 2 additions & 2 deletions kronfluence/factor/eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def fit_lambda_matrices_with_loader(

with torch.no_grad():
saved_factors: FACTOR_TYPE = {}
- for covariance_factor_name in LAMBDA_FACTOR_NAMES:
- saved_factors[covariance_factor_name] = load_factors(model=model, factor_name=covariance_factor_name)
+ for lambda_factor_name in LAMBDA_FACTOR_NAMES:
+ saved_factors[lambda_factor_name] = load_factors(model=model, factor_name=lambda_factor_name)
set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False)
return num_data_processed, saved_factors
2 changes: 1 addition & 1 deletion kronfluence/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _compute_per_sample_gradient(
Returns:
torch.Tensor:
The per-sample-gradient tensor. The per-sample-gradient is a 3-dimensional matrix
- with dimension `batch_size x input_dim x gradient_dim`. An additional dimension is added
+ with dimension `batch_size x gradient_dim x input_dim`. An additional dimension is added
when the bias term is used.
"""
if self.original_module.bias is not None:
Expand Down
23 changes: 16 additions & 7 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,9 @@ class TrackedModule(nn.Module):

SUPPORTED_MODULES: Dict[Type[nn.Module], Any] = {}

- def __init_subclass__(cls, module_type: Optional[Type[nn.Module]] = None, **kwargs) -> None:
+ def __init_subclass__(cls, module_type: Type[nn.Module] = None, **kwargs) -> None:
"""Automatically registers subclasses as supported modules."""
super().__init_subclass__(**kwargs)
- assert module_type is not None
if module_type is not None:
cls.SUPPORTED_MODULES[module_type] = cls

Expand Down Expand Up @@ -317,7 +316,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
- self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+ self._registered_hooks.append(
+ self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+ )

def _release_covariance_matrices(self) -> None:
"""Clears the stored activation and pseudo-gradient covariance matrices from memory."""
Expand Down Expand Up @@ -492,7 +493,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
- self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+ self._registered_hooks.append(
+ self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+ )

def _release_lambda_matrix(self) -> None:
"""Clears the stored Lambda matrix from memory."""
Expand Down Expand Up @@ -608,7 +611,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
- self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+ self._registered_hooks.append(
+ self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+ )

def _release_preconditioned_gradient(self) -> None:
"""Clears the preconditioned per-sample-gradient from memory."""
Expand Down Expand Up @@ -728,7 +733,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
- self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+ self._registered_hooks.append(
+ self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+ )

def _register_self_score_hooks(self) -> None:
"""Installs forward and backward hooks for computation of self-influence scores."""
Expand Down Expand Up @@ -786,7 +793,9 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

if self.factor_args.immediate_gradient_removal:
- self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook))
+ self._registered_hooks.append(
+ self.original_module.register_full_backward_hook(full_backward_gradient_removal_hook)
+ )

def release_scores(self) -> None:
"""Clears the influence scores from memory."""
Expand Down

0 comments on commit 0f2075a

Please sign in to comment.