Skip to content

Commit

Permalink
Various optimizations done
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 4, 2024
1 parent 2828847 commit 427f6ce
Show file tree
Hide file tree
Showing 45 changed files with 3,025 additions and 2,481 deletions.
98 changes: 50 additions & 48 deletions kronfluence/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,33 @@ def prepare_model(
model: nn.Module,
task: Task,
) -> nn.Module:
"""Prepares the model before passing it to `Analyzer`. This function sets `param.requires_grad = False`
for all modules and installs `TrackedModule` to supported modules. This `TrackedModule` keeps track of relevant
statistics needed to compute influence scores.
"""Prepares the model for analysis by setting all parameters and buffers to non-trainable
and installing `TrackedModule` wrappers on supported modules.
Args:
model (nn.Module):
The PyTorch model to be analyzed.
The PyTorch model to be prepared for analysis.
task (Task):
The specific task associated with the model.
The specific task associated with the model, used for `TrackedModule` installation.
Returns:
nn.Module:
The PyTorch model with `param.requires_grad = False` on all modules and with `TrackedModule` installed.
The prepared model with non-trainable parameters and `TrackedModule` wrappers.
"""
model.eval()
for params in model.parameters():
params.requires_grad = False
for buffers in model.buffers():
buffers.requires_grad = False
# Install `TrackedModule` to the model.

# Install `TrackedModule` wrappers on supported modules.
model = wrap_tracked_modules(model=model, task=task)
return model


class Analyzer(FactorComputer, ScoreComputer):
"""Handles the computation of all factors (e.g., covariance and Lambda matrices for EKFAC)
and influence scores for a given PyTorch model."""
"""Handles the computation of factors (e.g., covariance and Lambda matrices for EKFAC) and
influence scores for a given PyTorch model."""

def __init__(
self,
Expand All @@ -61,33 +61,33 @@ def __init__(
output_dir: str = "./influence_results",
disable_model_save: bool = True,
) -> None:
"""Initializes an instance of the Analyzer class.
"""Initializes an instance of the `Analyzer` class.
Args:
analysis_name (str):
The unique identifier for the analysis, used to organize and retrieve the results.
Unique identifier for the analysis, used for organizing results.
model (nn.Module):
The PyTorch model to be analyzed.
task (Task):
The specific task associated with the model.
cpu (bool, optional):
Specifies whether the analysis should be explicitly performed using the CPU.
Defaults to False, utilizing GPU resources if available.
If `True`, forces analysis to be performed on CPU. Defaults to `False`.
log_level (int, optional):
The logging level to use (e.g., logging.DEBUG, logging.INFO). Defaults to the root logging level.
Logging level (e.g., logging.DEBUG, logging.INFO). Defaults to root logging level.
log_main_process_only (bool, optional):
If True, restricts logging to the main process. Defaults to True.
If `True`, restricts logging to the main process. Defaults to `True`.
profile (bool, optional):
Enables the generation of performance profiling logs. This can be useful for
identifying bottlenecks or performance issues. Defaults to False.
If `True`, enables performance profiling logs. Defaults to `False`.
disable_tqdm (bool, optional):
Disables TQDM progress bars. Defaults to False.
If `True`, disables TQDM progress bars. Defaults to `False`.
output_dir (str):
The file path to the directory, where analysis results will be stored. If the directory
does not exist, it will be created. Defaults to './influence_results'.
Directory path for storing analysis results. Defaults to './influence_results'.
disable_model_save (bool, optional):
If set to True, prevents the saving of the model's state_dict. When the provided model is different
from the previously saved model, it will raise an Exception. Defaults to True.
If `True`, prevents saving the model's state_dict. Defaults to `True`.
Raises:
ValueError:
If the provided model differs from a previously saved model when `disable_model_save` is `False`.
"""
super().__init__(
name=analysis_name,
Expand All @@ -103,17 +103,17 @@ def __init__(
self.logger.info(f"Initializing Computer with parameters: {locals()}")
self.logger.debug(f"Process state configuration:\n{repr(self.state)}")

# Saves model parameters.
# Save model parameters if necessary.
if self.state.is_main_process and not disable_model_save:
self._save_model()
self.state.wait_for_everyone()

def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None:
"""Sets the default DataLoader parameters to use for all DataLoaders.
"""Sets the default DataLoader arguments.
Args:
dataloader_kwargs (DataLoaderKwargs):
The object containing parameters for DataLoader.
The object containing arguments for DataLoader.
"""
self._dataloader_params = dataloader_kwargs

Expand All @@ -125,7 +125,7 @@ def _save_model(self) -> None:

if model_save_path.exists():
self.logger.info(f"Found existing saved model at `{model_save_path}`.")
# Load the existing model's state_dict for comparison.
# Load existing model's `state_dict` for comparison.
loaded_state_dict = load_file(model_save_path)
if not verify_models_equivalence(loaded_state_dict, extracted_model.state_dict()):
error_msg = (
Expand All @@ -152,27 +152,24 @@ def fit_all_factors(
factor_args: Optional[FactorArguments] = None,
overwrite_output_dir: bool = False,
) -> None:
"""Computes all necessary factors for the given factor strategy. As an example, EK-FAC
requires (1) computing covariance matrices, (2) performing Eigendecomposition, and
(3) computing Lambda (corrected eigenvalues) matrices.
"""Computes all necessary factors for the given factor strategy.
Args:
factors_name (str):
The unique identifier for the factor, used to organize and retrieve the results.
Unique identifier for the factor, used for organizing results.
dataset (data.Dataset):
The dataset that will be used to fit all the factors.
Dataset used to fit all the factors.
per_device_batch_size (int, optional):
The per-device batch size used to fit the factors. If not specified, executable
per-device batch size is automatically determined.
Per-device batch size for factor fitting. If not specified, executable per-device batch size
is automatically determined.
initial_per_device_batch_size_attempt (int):
The initial attempted per-device batch size when the batch size is not provided.
Initial batch size attempt when `per_device_batch_size` is not explicitly provided. Defaults to `4096`.
dataloader_kwargs (DataLoaderKwargs, optional):
Controls additional arguments for PyTorch's DataLoader.
Additional arguments for PyTorch's DataLoader.
factor_args (FactorArguments, optional):
Arguments related to computing the factors. If not specified,
the default values of `FactorArguments` will be used.
Arguments for factor computation. Defaults to `FactorArguments` default values.
overwrite_output_dir (bool, optional):
If True, the existing factors with the same name will be overwritten.
If `True`, overwrites existing factors with the same name. Defaults to `False`.
"""
self.fit_covariance_matrices(
factors_name=factors_name,
Expand Down Expand Up @@ -200,36 +197,41 @@ def fit_all_factors(

@staticmethod
def load_file(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
"""Loads the `.safetensors` file at the given path from disk.
See https://github.com/huggingface/safetensors.
"""Loads a `safetensors` file from the given path.
Args:
path (Path):
The path to the `.safetensors` file.
The path to the `safetensors` file.
Returns:
Dict[str, torch.Tensor]:
The contents of the file, which is the dictionary mapping string to tensors.
Dictionary mapping strings to tensors from the loaded file.
Raises:
FileNotFoundError:
If the specified file does not exist.
Note:
For more information on safetensors, see https://github.com/huggingface/safetensors.
"""
if isinstance(path, str):
path = Path(path).resolve()
if not path.exists():
raise FileNotFoundError(f"File does not exists at `{path}`.")
raise FileNotFoundError(f"File not found: {path}.")
return load_file(path)

@staticmethod
def get_module_summary(model: nn.Module) -> str:
"""Returns the formatted summary of the modules in model. Useful identifying which modules to
compute influence scores.
"""Generates a formatted summary of the model's modules, excluding those without parameters. This summary is
useful for identifying which modules to compute influence scores for.
Args:
model (nn.Module):
The PyTorch model to be investigated.
The PyTorch model to be summarized.
Returns:
str:
The formatted string summary of the model.
A formatted string containing the model summary, including module names and representations.
"""
format_str = "==Model Summary=="
for module_name, module in model.named_modules():
Expand Down
Loading

0 comments on commit 427f6ce

Please sign in to comment.