Skip to content

Commit

Permalink
Refactor score computations
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jun 21, 2024
1 parent 9aed5cc commit b1e78e0
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 214 deletions.
4 changes: 4 additions & 0 deletions kronfluence/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ class ScoreArguments(Arguments):
default=None,
metadata={"help": "Dtype for automatic mixed precision (AMP). Disables AMP if None."},
)
shared_parameters_exist: bool = field(
default=False,
metadata={"help": "Specifies whether the shared parameters exist in the forward pass."},
)

# Partition configuration. #
data_partition_size: int = field(
Expand Down
12 changes: 8 additions & 4 deletions kronfluence/computer/score_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _configure_and_save_score_args(
factors_name: str,
overwrite_output_dir: bool,
) -> Tuple[FactorArguments, ScoreArguments]:
"""Configure the provided factor arguments and save it in disk."""
"""Configures the provided factor arguments and saves them to disk."""
if score_args is None:
score_args = ScoreArguments()
self.logger.info(f"Score arguments not provided. Using the default configuration: {score_args}.")
Expand Down Expand Up @@ -173,7 +173,7 @@ def _find_executable_pairwise_scores_batch_size(

def executable_batch_size_func(batch_size: int) -> None:
self.logger.info(f"Attempting to set per-device batch size to {batch_size}.")
# Release all memory that could be caused by the previous OOM.
# Releases all memory that may have been retained after the previous OOM.
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
release_memory()
total_batch_size = batch_size * self.state.num_processes
Expand Down Expand Up @@ -203,6 +203,7 @@ def executable_batch_size_func(batch_size: int) -> None:
train_loader=train_loader,
per_device_query_batch_size=per_device_query_batch_size,
tracked_module_names=tracked_modules_name,
disable_tqdm=True
)

per_device_batch_size = find_executable_batch_size(
Expand Down Expand Up @@ -403,6 +404,7 @@ def compute_pairwise_scores(
score_args=score_args,
factor_args=factor_args,
tracked_module_names=module_partition_names[module_partition],
disable_tqdm=self.disable_tqdm,
)
end_time = get_time(state=self.state)
elapsed_time = end_time - start_time
Expand Down Expand Up @@ -487,7 +489,7 @@ def _find_executable_self_scores_batch_size(

def executable_batch_size_func(batch_size: int) -> None:
self.logger.info(f"Attempting to set per-device batch size to {batch_size}.")
# Release all memory that could be caused by the previous OOM.
# Releases all memory that may have been retained after the previous OOM.
set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False)
release_memory()
total_batch_size = batch_size * self.state.num_processes
Expand All @@ -512,6 +514,7 @@ def executable_batch_size_func(batch_size: int) -> None:
score_args=score_args,
factor_args=factor_args,
tracked_module_names=tracked_modules_name,
disable_tqdm=True,
)

per_device_batch_size = find_executable_batch_size(
Expand All @@ -536,7 +539,7 @@ def compute_self_scores(
overwrite_output_dir: bool = False,
) -> Optional[SCORE_TYPE]:
"""Computes self-influence scores for the given score configuration. As an example,
for T training dataset, the self-influence scores are represented as T-dimensional vector.
for a training dataset with T examples, the self-influence scores are represented as a T-dimensional vector.
Args:
scores_name (str):
Expand Down Expand Up @@ -691,6 +694,7 @@ def compute_self_scores(
score_args=score_args,
factor_args=factor_args,
tracked_module_names=module_partition_names[module_partition],
disable_tqdm=self.disable_tqdm,
)
end_time = get_time(state=self.state)
elapsed_time = end_time - start_time
Expand Down
5 changes: 5 additions & 0 deletions kronfluence/factor/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from kronfluence.arguments import FactorArguments
from kronfluence.module.tracked_module import ModuleMode
from kronfluence.module.utils import (
get_tracked_module_names,
load_factors,
remove_gradient_scale,
set_attention_mask,
set_gradient_scale,
set_mode,
Expand Down Expand Up @@ -131,6 +133,8 @@ def fit_covariance_matrices_with_loader(
"""
with torch.no_grad():
update_factor_args(model=model, factor_args=factor_args)
if tracked_module_names is None:
tracked_module_names = get_tracked_module_names(model=model)
set_mode(
model=model,
tracked_module_names=tracked_module_names,
Expand Down Expand Up @@ -197,6 +201,7 @@ def fit_covariance_matrices_with_loader(

# Clean up the memory.
model.zero_grad(set_to_none=True)
remove_gradient_scale(model=model)
set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False)

return num_data_processed, saved_factors
3 changes: 3 additions & 0 deletions kronfluence/factor/eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def fit_lambda_matrices_with_loader(
"""
with torch.no_grad():
update_factor_args(model=model, factor_args=factor_args)
if tracked_module_names is None:
tracked_module_names = get_tracked_module_names(model=model)
set_mode(
model=model,
tracked_module_names=tracked_module_names,
Expand Down Expand Up @@ -344,6 +346,7 @@ def fit_lambda_matrices_with_loader(

# Clean up the memory.
model.zero_grad(set_to_none=True)
remove_gradient_scale(model=model)
set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False)

return num_data_processed, saved_factors
Loading

0 comments on commit b1e78e0

Please sign in to comment.