Finalize tests
pomonam committed Mar 19, 2024
1 parent d940aa3 commit 3554110
Showing 23 changed files with 416 additions and 92 deletions.
54 changes: 45 additions & 9 deletions kronfluence/analyzer.py
@@ -3,13 +3,13 @@

import torch
from accelerate.utils import extract_model_from_parallel
from computer.factor_computer import FactorComputer
from computer.score_computer import ScoreComputer
from safetensors.torch import save_file
from torch import nn
from torch.utils import data

from kronfluence.arguments import FactorArguments
from kronfluence.computer.factor_computer import FactorComputer
from kronfluence.computer.score_computer import ScoreComputer
from kronfluence.module.utils import wrap_tracked_modules
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs
@@ -29,6 +29,11 @@ def prepare_model(
The PyTorch model to be analyzed.
task (Task):
The specific task associated with the model.
Returns:
nn.Module:
The same PyTorch model with `param.requires_grad = False` on all modules that do not require influence
computations and with `TrackedModule` installed.
"""
model.eval()
for params in model.parameters():
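
With the docstring above now spelling out the contract — parameters frozen, `TrackedModule` wrappers installed — a minimal usage sketch looks like the following. The `ClassificationTask` here is hypothetical; the exact abstract methods required by `kronfluence.task.Task` may differ from this sketch.

```python
# A minimal sketch, not verbatim library documentation. `ClassificationTask`
# is a hypothetical `Task` subclass for illustration only.
import torch.nn.functional as F
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.task import Task


class ClassificationTask(Task):
    def compute_train_loss(self, batch, model, sample=False):
        inputs, targets = batch
        return F.cross_entropy(model(inputs), targets, reduction="sum")

    def compute_measurement(self, batch, model):
        return self.compute_train_loss(batch, model)


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
task = ClassificationTask()

# Freezes `requires_grad` on all parameters and wraps supported layers with
# `TrackedModule`, as described by the docstring added in this commit.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="demo", model=model, task=task)
```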
@@ -79,7 +84,8 @@ def __init__(
The file path to the directory where analysis results will be stored. If the directory
does not exist, it will be created. Defaults to './analyses'.
disable_model_save (bool, optional):
If set to True, prevents the saving of the model state. Defaults to True.
If set to True, prevents saving of the model's state_dict. If the provided model differs from
a previously saved model, an exception will be raised. Defaults to True.
"""
super().__init__(
name=analysis_name,
@@ -108,10 +114,11 @@ def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None:
"""
self._dataloader_params = dataloader_kwargs

@torch.no_grad()
def _save_model(self) -> None:
"""Saves the model to the output directory."""
model_save_path = self.output_dir / "model.safetensors"
extracted_model = extract_model_from_parallel(self.model)
extracted_model = extract_model_from_parallel(model=self.model, keep_fp32_wrapper=True)

if model_save_path.exists():
self.logger.info(f"Found existing saved model at `{model_save_path}`.")
@@ -120,13 +127,13 @@ def _save_model(self) -> None:
if not verify_models_equivalence(loaded_state_dict, extracted_model.state_dict()):
error_msg = (
"Detected a difference between the current model and the one saved at "
f"{model_save_path}. Consider using a different `analysis_name` to "
f"`{model_save_path}`. Consider using a different `analysis_name` to "
f"avoid conflicts."
)
self.logger.error(error_msg)
raise ValueError(error_msg)
else:
self.logger.info(f"No existing model found at {model_save_path}.")
self.logger.info(f"No existing model found at `{model_save_path}`.")
state_dict = extracted_model.state_dict()
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
save_file(state_dict, model_save_path)
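
The `.clone().contiguous()` above is needed because `safetensors.save_file` refuses tensors that share storage or are non-contiguous. A standalone sketch of the same save/verify round trip:

```python
# Standalone sketch of the round trip used by `_save_model`; assumes only
# the `safetensors` API shown in this diff.
import torch
from safetensors.torch import load_file, save_file
from torch import nn

model = nn.Linear(4, 2)

# Clone to contiguous copies that own their storage before serializing.
state_dict = {k: v.clone().contiguous() for k, v in model.state_dict().items()}
save_file(state_dict, "model.safetensors")

reloaded = load_file("model.safetensors")
assert all(torch.equal(reloaded[k], state_dict[k]) for k in state_dict)
```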
@@ -189,12 +196,41 @@ def fit_all_factors(
)

@staticmethod
def load_file(path: Path) -> Optional[Dict[str, torch.Tensor]]:
"""Loads the `.safetensor` file at the given path from disk."""
def load_file(path: Path) -> Dict[str, torch.Tensor]:
"""Loads the `.safetensors` file at the given path from disk.
See https://github.com/huggingface/safetensors.
Args:
path (Path):
The path to the `.safetensors` file.
Returns:
Dict[str, torch.Tensor]:
The contents of the file: a dictionary mapping strings to tensors.
"""
if not path.exists():
raise FileNotFoundError(f"File does not exist at `{path}`.")
return load_file(path)

@staticmethod
def get_module_summary(model: nn.Module) -> str:
pass
"""Returns the formatted summary of the modules in model. Useful identifying which modules to
compute influence scores.
Args:
model (nn.Module):
The PyTorch model to be investigated.
Returns:
str:
The formatted string summary of the model.
"""
format_str = "==Model Summary=="
for module_name, module in model.named_modules():
if len(list(module.children())) > 0:
continue
if len(list(module.parameters())) == 0:
continue
format_str += f"\nModule Name: `{module_name}`, Module: {repr(module)}"
return format_str
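
The new implementation walks `named_modules()` and keeps only leaf modules that own parameters, so containers and parameter-free modules such as activations are skipped. For example:

```python
# Usage sketch for `get_module_summary` on a small model; the outer
# `Sequential` (has children) and the `ReLU` (no parameters) are skipped.
from torch import nn

from kronfluence.analyzer import Analyzer

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
print(Analyzer.get_module_summary(model))
# ==Model Summary==
# Module Name: `0`, Module: Linear(in_features=8, out_features=4, bias=True)
# Module Name: `2`, Module: Linear(in_features=4, out_features=2, bias=True)
```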
28 changes: 15 additions & 13 deletions kronfluence/arguments.py
@@ -1,11 +1,9 @@
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional

import torch

from kronfluence.factor.config import FactorStrategy


@dataclass
class Arguments:
@@ -31,7 +29,8 @@ def to_str_dict(self) -> Dict[str, str]:
class FactorArguments(Arguments):
"""Arguments for computing influence factors."""

strategy: Union[FactorStrategy, str] = field(
# General configuration. #
strategy: str = field(
default="ekfac",
metadata={"help": "Strategy for computing preconditioning factors."},
)
@@ -152,6 +151,7 @@
class ScoreArguments(Arguments):
"""Arguments for computing influence scores."""

# General configuration. #
damping: Optional[float] = field(
default=None,
metadata={
@@ -163,13 +163,21 @@ class ScoreArguments(Arguments):
default=False,
metadata={"help": "Whether to immediately remove computed `.grad` by Autograd within the backward hook."},
)
cached_activation_cpu_offload: bool = field(
default=False,
metadata={
"help": "Whether to offload cached activation to CPU for computing the "
"per-sample-gradient. This is helpful when the available GPU memory is limited."
},
)
distributed_sync_steps: int = field(
default=1_000,
metadata={
"help": "Specifies the total iteration step to synchronize the process when using distributed setting."
},
)

# Partition configuration. #
data_partition_size: int = field(
default=1,
metadata={
@@ -187,34 +195,28 @@ class ScoreArguments(Arguments):
},
)

# Score configuration. #
per_module_score: bool = field(
default=False,
metadata={
"help": "Whether to obtain per-module scores instead of the summed scores across all modules. "
"This is useful when performing layer-wise influence analysis."
},
)

query_gradient_rank: Optional[int] = field(
default=None,
metadata={"help": "Rank for the query gradient. Does not apply low-rank approximation if None."},
)

# Dtype configuration. #
query_gradient_svd_dtype: torch.dtype = field(
default=torch.float64,
metadata={"help": "Dtype for performing singular value decomposition (SVD) on the query gradient."},
)

score_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing and storing influence scores."},
)
cached_activation_cpu_offload: bool = field(
default=False,
metadata={
"help": "Whether to offload cached activation to CPU for computing the "
"per-sample-gradient. This is helpful when the available GPU memory is limited."
},
)
per_sample_gradient_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing per-sample-gradients."},
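
Taken together with the regrouped fields above, constructing arguments now looks like the sketch below. Only fields visible in this diff are shown; the dataclasses define more.

```python
# Construction sketch using only fields that appear in this diff.
import torch

from kronfluence.arguments import FactorArguments, ScoreArguments

factor_args = FactorArguments(strategy="ekfac")  # plain string, no FactorStrategy enum
score_args = ScoreArguments(
    damping=None,                            # general configuration
    cached_activation_cpu_offload=True,      # offload cached activations to CPU
    data_partition_size=4,                   # partition configuration
    module_partition_size=1,
    per_module_score=False,                  # score configuration
    query_gradient_rank=32,                  # enable low-rank query gradients
    query_gradient_svd_dtype=torch.float64,  # dtype configuration
    score_dtype=torch.float32,
    per_sample_gradient_dtype=torch.float32,
)
```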
10 changes: 5 additions & 5 deletions kronfluence/computer/computer.py
@@ -12,7 +12,6 @@
from torch.utils.data import DistributedSampler, SequentialSampler

from kronfluence.arguments import Arguments, FactorArguments, ScoreArguments
from kronfluence.factor.config import FactorConfig
from kronfluence.factor.covariance import (
covariance_matrices_exist,
load_covariance_matrices,
@@ -24,10 +23,7 @@
load_lambda_matrices,
)
from kronfluence.module.constants import FACTOR_TYPE, SCORE_TYPE
from kronfluence.module.utils import (
get_tracked_module_names,
make_modules_partition,
)
from kronfluence.module.utils import get_tracked_module_names, make_modules_partition
from kronfluence.score.pairwise import load_pairwise_scores, pairwise_scores_exist
from kronfluence.score.self import load_self_scores, self_scores_exist
from kronfluence.task import Task
@@ -379,6 +375,10 @@ def load_self_scores(self, scores_name: str) -> Optional[SCORE_TYPE]:

def load_all_factors(self, factors_name: str) -> FACTOR_TYPE:
"""Loads factors from disk."""
from kronfluence.factor.config import (
FactorConfig, # pylint: disable=import-outside-toplevel
)

factor_args = self.load_factor_args(factors_name)
factors_output_dir = self.factors_output_dir(factors_name=factors_name)
if factor_args is None:
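
Deferring the `FactorConfig` import into the function body — presumably to break an import cycle at module load time, given the matching pylint suppression — follows the standard pattern sketched below with stand-in names:

```python
# Hypothetical, minimal demo of the deferred-import pattern; `FactorConfig`
# here is a stand-in, not kronfluence's real class. The import inside the
# function is resolved on first call, not while the module is loading.
def load_config():
    from dataclasses import dataclass  # pylint: disable=import-outside-toplevel

    @dataclass
    class FactorConfig:  # stand-in for kronfluence.factor.config.FactorConfig
        strategy: str = "ekfac"

    return FactorConfig()

print(load_config())  # FactorConfig(strategy='ekfac')
```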
27 changes: 20 additions & 7 deletions kronfluence/computer/score_computer.py
@@ -133,7 +133,7 @@ def _aggregate_scores(
),
dim=dim,
)
save_fnc(output_dir=scores_output_dir, scores=aggregated_scores)
save_fnc(output_dir=scores_output_dir, scores=aggregated_scores, metadata=score_args.to_str_dict())
end_time = time.time()
elapsed_time = end_time - start_time
self.logger.info(f"Aggregated all partitioned scores in {elapsed_time:.2f} seconds.")
@@ -324,7 +324,7 @@ def compute_pairwise_scores(
data_partition_size=score_args.data_partition_size,
target_data_partitions=target_data_partitions,
)
max_partition_examples = len(train_dataset) // factor_args.covariance_data_partition_size
max_partition_examples = len(train_dataset) // score_args.data_partition_size
module_partition_names, target_module_partitions = self._get_module_partition(
module_partition_size=score_args.module_partition_size,
target_module_partitions=target_module_partitions,
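
The divisor fix above (and its twin in `compute_self_scores` below) matters whenever the factor and score partition counts differ: partition capacity must come from `score_args.data_partition_size`. A worked example with illustrative numbers:

```python
# Sketch of the partition arithmetic fixed above, with illustrative numbers.
train_dataset_size = 10_000
covariance_data_partition_size = 2  # factor_args (wrong divisor before the fix)
data_partition_size = 4             # score_args (correct divisor)

# Before: 10_000 // 2 = 5_000 examples per partition (too large).
# After:  10_000 // 4 = 2_500 examples per partition.
max_partition_examples = train_dataset_size // data_partition_size
print(max_partition_examples)  # 2500
```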
@@ -428,7 +428,13 @@ def compute_pairwise_scores(

@torch.no_grad()
def aggregate_pairwise_scores(self, scores_name: str) -> None:
"""Aggregates pairwise scores computed for all data and module partitions."""
"""Aggregates all partitioned pairwise scores. The scores will not be aggregated if scores
for some data or module partitions are missing.
Args:
scores_name (str):
The unique identifier for the score, used to organize and retrieve the results.
"""
score_args = self.load_score_args(scores_name=scores_name)
if score_args is None:
error_msg = (
@@ -607,7 +613,7 @@ def compute_self_scores(
data_partition_size=score_args.data_partition_size,
target_data_partitions=target_data_partitions,
)
max_partition_examples = len(train_dataset) // factor_args.covariance_data_partition_size
max_partition_examples = len(train_dataset) // score_args.data_partition_size
module_partition_names, target_module_partitions = self._get_module_partition(
module_partition_size=score_args.module_partition_size,
target_module_partitions=target_module_partitions,
@@ -683,6 +689,7 @@ def compute_self_scores(
output_dir=scores_output_dir,
scores=scores,
partition=partition,
metadata=score_args.to_str_dict(),
)
self.state.wait_for_everyone()
del scores, train_loader
@@ -699,8 +706,14 @@ def compute_self_scores(
self._log_profile_summary()

@torch.no_grad()
def aggregate_self_scores(self, scores_name: str) -> Optional[SCORE_TYPE]:
"""Aggregates self-influence scores computed for all data and module partitions."""
def aggregate_self_scores(self, scores_name: str) -> None:
"""Aggregates all partitioned self-influence scores. The scores will not be aggregated if scores
for some data or module partitions are missing.
Args:
scores_name (str):
The unique identifier for the score, used to organize and retrieve the results.
"""
score_args = self.load_score_args(scores_name=scores_name)
if score_args is None:
error_msg = (
@@ -710,7 +723,7 @@ def aggregate_self_scores(self, scores_name: str) -> Optional[SCORE_TYPE]:
self.logger.error(error_msg)
raise ValueError(error_msg)

return self._aggregate_scores(
self._aggregate_scores(
scores_name=scores_name,
score_args=score_args,
exists_fnc=self_scores_exist,
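
Both aggregation methods now write the combined scores as a side effect and return `None`. A usage sketch, assuming an `Analyzer` built as in the earlier sketch; the keyword arguments to `compute_self_scores` are illustrative, not exhaustive:

```python
# Usage sketch; `analyzer`, `train_dataset`, and `score_args` come from the
# earlier sketches, and the argument names here are illustrative.
analyzer.compute_self_scores(
    scores_name="prototype",
    factors_name="initial_factors",
    train_dataset=train_dataset,
    score_args=score_args,
)

# Returns None; raises if any data/module partition is missing on disk.
analyzer.aggregate_self_scores(scores_name="prototype")
scores = analyzer.load_self_scores(scores_name="prototype")
```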
24 changes: 16 additions & 8 deletions kronfluence/module/tracked_module.py
@@ -350,7 +350,7 @@ def _covariance_matrices_available(self) -> bool:
@torch.no_grad()
def synchronize_covariance_matrices(self) -> None:
"""Aggregates covariance matrices across multiple devices or nodes in a distributed setting."""
if dist.is_initialized() and torch.cuda.is_available():
if dist.is_initialized() and torch.cuda.is_available() and self._covariance_matrices_available():
# Note that only the main process holds the aggregated covariance matrix.
for covariance_factor_name in COVARIANCE_FACTOR_NAMES:
dist.reduce(
@@ -519,6 +519,7 @@ def _release_lambda_matrix(self) -> None:
del self._storage[lambda_factor_name]
self._storage[lambda_factor_name] = None
self._cached_activations = []
del self._cached_per_sample_gradient
self._cached_per_sample_gradient = None

def _lambda_matrix_available(self) -> bool:
@@ -531,7 +532,7 @@ def _lambda_matrix_available(self) -> bool:
@torch.no_grad()
def synchronize_lambda_matrices(self) -> None:
"""Aggregates Lambda matrices across multiple devices or nodes in a distributed setting."""
if dist.is_initialized() and torch.cuda.is_available():
if dist.is_initialized() and torch.cuda.is_available() and self._lambda_matrix_available():
# Note that only the main process holds the aggregated Lambda matrix.
for lambda_factor_name in LAMBDA_FACTOR_NAMES:
torch.distributed.reduce(
@@ -559,16 +560,17 @@ def _compute_low_rank_preconditioned_gradient(
Low-rank matrices that approximate the original preconditioned gradient.
"""
U, S, V = torch.linalg.svd( # pylint: disable=not-callable
preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype)
preconditioned_gradient.contiguous().to(dtype=self.score_args.query_gradient_svd_dtype),
full_matrices=False,
)
rank = self.score_args.query_gradient_rank
U_k = U[:, :, :rank]
S_k = S[:, :rank]
# Avoid holding the full memory of the original tensor before indexing.
V_k = V[:, :, :rank].clone()
V_k = V[:, :rank, :].clone()
return [
torch.matmul(U_k, torch.diag_embed(S_k)).contiguous().to(dtype=self.score_args.score_dtype),
torch.transpose(V_k, 1, 2).contiguous().to(dtype=self.score_args.score_dtype),
torch.matmul(U_k, torch.diag_embed(S_k)).to(dtype=self.score_args.score_dtype).contiguous(),
V_k.to(dtype=self.score_args.score_dtype).contiguous(),
]

def _register_precondition_gradient_hooks(self) -> None:
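
The slicing fix in the hunk above is subtle: `torch.linalg.svd` returns `Vh`, whose rows are the right singular vectors, so the top-`rank` slice runs over dim 1; `full_matrices=False` and the `.clone()` keep only the reduced factors in memory. A standalone check:

```python
# Standalone check of the low-rank construction fixed above. Note that
# `torch.linalg.svd` returns `Vh`, whose *rows* are right singular vectors,
# so the top-`rank` slice is `Vh[:, :rank, :]`, not `Vh[:, :, :rank]`.
import torch

g = torch.randn(4, 10, 6)  # batch of per-sample gradients: (batch, out, in)
rank = 3

U, S, Vh = torch.linalg.svd(g, full_matrices=False)
U_k, S_k = U[:, :, :rank], S[:, :rank]
V_k = Vh[:, :rank, :].clone()  # owns its storage, so the full `Vh` can be freed

left = torch.matmul(U_k, torch.diag_embed(S_k))  # (batch, out, rank)
approx = torch.matmul(left, V_k)                 # best rank-3 reconstruction
print(torch.linalg.matrix_norm(g - approx))      # approximation error per sample
```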
@@ -640,6 +642,7 @@ def _release_preconditioned_gradient(self) -> None:
del self._storage[PRECONDITIONED_GRADIENT_NAME]
self._storage[PRECONDITIONED_GRADIENT_NAME] = None
self._cached_activations = []
del self._cached_per_sample_gradient
self._cached_per_sample_gradient = None

def get_preconditioned_gradient_batch_size(self) -> Optional[int]:
@@ -665,10 +668,14 @@ def truncate_preconditioned_gradient(self, keep_size: int) -> None:
:keep_size
].clone()

def _preconditioned_gradient_available(self) -> bool:
"""Checks if the preconditioned matrices are currently stored in the storage."""
return self._storage[PRECONDITIONED_GRADIENT_NAME] is not None

@torch.no_grad()
def synchronize_preconditioned_gradient(self, num_processes: int) -> None:
"""Stacks preconditioned gradient across multiple devices or nodes in a distributed setting."""
if dist.is_initialized() and torch.cuda.is_available():
if dist.is_initialized() and torch.cuda.is_available() and self._preconditioned_gradient_available():
if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list):
assert len(self._storage[PRECONDITIONED_GRADIENT_NAME]) == 2
for i in range(len(self._storage[PRECONDITIONED_GRADIENT_NAME])):
@@ -746,7 +753,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
# The preconditioned gradient is stored as a low-rank approximation.
left_mat, right_mat = self._storage[PRECONDITIONED_GRADIENT_NAME]
self._storage[PAIRWISE_SCORE_MATRIX_NAME] = contract(
"qki,toi,qok->qt",
"qci,toi,qok->qt",
right_mat,
self._cached_per_sample_gradient,
left_mat,
@@ -854,5 +861,6 @@ def release_scores(self) -> None:
del self._storage[SELF_SCORE_VECTOR_NAME]
self._storage[SELF_SCORE_VECTOR_NAME] = None
self._cached_activations = []
del self._cached_per_sample_gradient
self._cached_per_sample_gradient = None
self._storge_at_current_device = False
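
Each `synchronize_*` method in this file now also verifies that the relevant factors are present before issuing collectives. A schematic of the guard (hypothetical helper, not library code):

```python
# Schematic of the guard pattern added in this commit; `synchronize_factors`
# is a hypothetical helper. The availability check is consistent across
# ranks, since every rank tracks the same modules, so no rank is left
# waiting on a collective that the others skipped.
from typing import Dict, List, Optional

import torch
import torch.distributed as dist


def synchronize_factors(
    storage: Dict[str, Optional[torch.Tensor]], names: List[str]
) -> None:
    if not (dist.is_initialized() and torch.cuda.is_available()):
        return
    if any(storage.get(name) is None for name in names):
        return  # factors never accumulated on this module; nothing to reduce
    for name in names:
        # Only the destination rank holds the fully aggregated factor.
        dist.reduce(storage[name], dst=0, op=dist.ReduceOp.SUM)
```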