Do GPU tests
pomonam committed Jul 7, 2024
1 parent 1c603fd commit 743070a
Showing 20 changed files with 355 additions and 193 deletions.
22 changes: 11 additions & 11 deletions kronfluence/analyzer.py
@@ -1,3 +1,4 @@
import copy
from pathlib import Path
from typing import Dict, Optional, Union

@@ -45,8 +46,7 @@ def prepare_model(


class Analyzer(FactorComputer, ScoreComputer):
"""Handles the computation of factors (e.g., covariance and Lambda matrices for EKFAC) and
influence scores for a given PyTorch model."""
"""Handles the computation of factors (e.g., covariance matrices) and scores for a given PyTorch model."""

def __init__(
self,
@@ -83,7 +83,7 @@ def __init__(
output_dir (str):
Directory path for storing analysis results. Defaults to './influence_results'.
disable_model_save (bool, optional):
If `True`, prevents saving the model's state_dict. Defaults to `True`.
If `True`, prevents saving the model's `state_dict`. Defaults to `True`.
Raises:
ValueError:
@@ -100,8 +100,8 @@ def __init__(
disable_tqdm=disable_tqdm,
output_dir=output_dir,
)
self.logger.info(f"Initializing Computer with parameters: {locals()}")
self.logger.debug(f"Process state configuration:\n{repr(self.state)}")
self.logger.info(f"Initializing `Analyzer` with parameters: {locals()}")
self.logger.info(f"Process state configuration:\n{repr(self.state)}")

# Save model parameters if necessary.
if self.state.is_main_process and not disable_model_save:
Expand All @@ -113,15 +113,15 @@ def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None:
Args:
dataloader_kwargs (DataLoaderKwargs):
The object containing arguments for DataLoader.
The object containing arguments for PyTorch DataLoader.
"""
self._dataloader_params = dataloader_kwargs

@torch.no_grad()
def _save_model(self) -> None:
"""Saves the model to the output directory."""
model_save_path = self.output_dir / "model.safetensors"
extracted_model = extract_model_from_parallel(model=self.model, keep_fp32_wrapper=True)
extracted_model = extract_model_from_parallel(model=copy.deepcopy(self.model), keep_fp32_wrapper=True)

if model_save_path.exists():
self.logger.info(f"Found existing saved model at `{model_save_path}`.")
@@ -151,13 +151,13 @@ def fit_all_factors(
factor_args: Optional[FactorArguments] = None,
overwrite_output_dir: bool = False,
) -> None:
"""Computes all necessary factors for the given factor strategy.
"""Computes all necessary factors for the given strategy.
Args:
factors_name (str):
Unique identifier for the factor, used for organizing results.
dataset (data.Dataset):
Dataset used to fit all the factors.
Dataset used to fit all influence factors.
per_device_batch_size (int, optional):
Per-device batch size for factor fitting. If not specified, executable per-device batch size
is automatically determined.
@@ -168,7 +168,7 @@ def fit_all_factors(
factor_args (FactorArguments, optional):
Arguments for factor computation. Defaults to `FactorArguments` default values.
overwrite_output_dir (bool, optional):
If `True`, overwrites existing factors with the same name. Defaults to `False`.
If `True`, overwrites existing factors with the same `factors_name`. Defaults to `False`.
"""
self.fit_covariance_matrices(
factors_name=factors_name,
@@ -211,7 +211,7 @@ def load_file(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
If the specified file does not exist.
Note:
For more information on safetensors, see https://github.com/huggingface/safetensors.
For more information on `safetensors`, see https://github.com/huggingface/safetensors.
"""
if isinstance(path, str):
path = Path(path).resolve()
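For context on how these `Analyzer` entry points fit together, a minimal usage sketch follows. Only `prepare_model`, the `output_dir` and `disable_model_save` options, `set_dataloader_kwargs`, `fit_all_factors`, and `load_file` appear in this diff; the remaining constructor arguments (`analysis_name`, `model`, `task`), the `FactorArguments(strategy=...)` field, and the `DataLoaderKwargs` import path are assumptions about the surrounding API, and `model`, `task`, and `train_dataset` are user-supplied placeholders.

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.utils.dataset import DataLoaderKwargs  # assumed import path

# Placeholders: a PyTorch model, a kronfluence Task, and a torch Dataset
# are expected to be defined by the user.
model = prepare_model(model=model, task=task)

analyzer = Analyzer(
    analysis_name="example_analysis",  # assumed parameter name
    model=model,
    task=task,
    output_dir="./influence_results",
    disable_model_save=True,
)

# Arguments forwarded to every PyTorch DataLoader the analyzer builds.
analyzer.set_dataloader_kwargs(DataLoaderKwargs(num_workers=4))

# Fit covariance, eigendecomposition, and Lambda factors in one call.
analyzer.fit_all_factors(
    factors_name="ekfac_factors",
    dataset=train_dataset,
    per_device_batch_size=None,  # let kronfluence determine an executable size
    factor_args=FactorArguments(strategy="ekfac"),
    overwrite_output_dir=False,
)

# Any saved safetensors artifact can be reloaded as a dict of tensors.
factors = Analyzer.load_file("path/to/saved_factors.safetensors")  # placeholder path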
17 changes: 5 additions & 12 deletions kronfluence/arguments.py
@@ -168,13 +168,6 @@ class ScoreArguments(Arguments):
default=False,
metadata={"help": "If `True`, offloads cached activations to CPU memory when computing per-sample gradients."},
)
einsum_minimize_size: bool = field(
default=False,
metadata={
"help": "If `True`, einsum operations find the contraction that minimizes the size of the "
"largest intermediate tensor."
},
)

# Partition configuration #
data_partitions: int = field(
@@ -209,7 +202,7 @@ class ScoreArguments(Arguments):
query_gradient_low_rank: Optional[int] = field(
default=None,
metadata={
"help": "Rank for the low-rank approximation of the query gradient. "
"help": "Rank for the low-rank approximation of the query gradient (query batching). "
"If `None`, no low-rank approximation is applied."
},
)
@@ -248,7 +241,7 @@ class ScoreArguments(Arguments):
)
per_sample_gradient_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Data type for per-sample gradient computation."},
metadata={"help": "Data type for query per-sample gradient computation."},
)
precondition_dtype: torch.dtype = field(
default=torch.float32,
@@ -260,8 +253,8 @@
)

def __post_init__(self) -> None:
if self.damping_factor is not None and self.damping_factor <= 0:
raise ValueError("`damping_factor` must be None or positive.")
if self.damping_factor is not None and self.damping_factor < 0:
raise ValueError("`damping_factor` must be `None` or positive.")

if any(partition <= 0 for partition in [self.data_partitions, self.module_partitions]):
raise ValueError("Both data and module partitions must be positive.")
@@ -270,4 +263,4 @@ def __post_init__(self) -> None:
raise ValueError("`query_gradient_accumulation_steps` must be positive.")

if self.query_gradient_low_rank is not None and self.query_gradient_low_rank <= 0:
raise ValueError("`query_gradient_low_rank` must be None or positive.")
raise ValueError("`query_gradient_low_rank` must be `None` or positive.")
2 changes: 1 addition & 1 deletion kronfluence/factor/covariance.py
@@ -29,7 +29,7 @@
PARTITION_TYPE,
)
from kronfluence.utils.logger import TQDM_BAR_FORMAT
from kronfluence.utils.state import State, no_sync, release_memory
from kronfluence.utils.state import State, no_sync


def covariance_matrices_save_path(
4 changes: 1 addition & 3 deletions kronfluence/factor/eigen.py
@@ -221,8 +221,6 @@ def perform_eigendecomposition(

pbar.update(1)

release_memory()

return eigen_factors


@@ -391,7 +389,7 @@ def fit_lambda_matrices_with_loader(
)
if eigen_factors is not None:
for name in eigen_factors:
set_factors(model=model, factor_name=name, factors=eigen_factors[name])
set_factors(model=model, factor_name=name, factors=eigen_factors[name], clone=True)

total_steps = 0
num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
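The `set_factors(..., clone=True)` call above copies the eigendecomposition factors into module storage instead of aliasing the loaded tensors, presumably so later in-place updates cannot corrupt the loader-side copies. A small generic illustration of the aliasing hazard this avoids (plain tensors, not kronfluence internals):

import torch

factors = {"activation_eigenvectors": torch.ones(4, 4)}

# Without cloning, storage aliases the caller's tensor, so an in-place
# modification on one side silently changes the other.
storage_aliased = {name: tensor for name, tensor in factors.items()}
factors["activation_eigenvectors"].zero_()
assert torch.all(storage_aliased["activation_eigenvectors"] == 0)

# Cloning decouples the two copies.
factors = {"activation_eigenvectors": torch.ones(4, 4)}
storage_cloned = {name: tensor.clone() for name, tensor in factors.items()}
factors["activation_eigenvectors"].zero_()
assert torch.all(storage_cloned["activation_eigenvectors"] == 1)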
18 changes: 5 additions & 13 deletions kronfluence/module/conv2d.py
@@ -4,7 +4,7 @@
import torch.nn.functional as F
from einconv.utils import get_conv_paddings
from einops import rearrange, reduce
from opt_einsum import DynamicProgramming, contract_expression, contract
from opt_einsum import DynamicProgramming, contract, contract_expression
from torch import nn
from torch.nn.modules.utils import _pair

@@ -116,7 +116,6 @@ def get_flattened_activation(self, input_activation: torch.Tensor) -> Tuple[torc
tensor=input_activation,
pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2",
)

if self.original_module.bias is not None:
input_activation = torch.cat(
[
@@ -145,7 +144,6 @@ def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Ten
tensor=input_activation,
pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2",
)

if self.original_module.bias is not None:
input_activation = torch.cat(
[
@@ -160,7 +158,7 @@ def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradien
input_activation = self._flatten_input_activation(input_activation=input_activation)
input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
summed_gradient = contract("bci,bco->io", output_gradient, input_activation)
summed_gradient = contract("bci,bco->io", output_gradient, input_activation).unsqueeze_(dim=0)
return summed_gradient.view((1, *summed_gradient.size()))

def compute_per_sample_gradient(
@@ -195,9 +193,7 @@ def compute_pairwise_score(
right_mat.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
optimize=DynamicProgramming(search_outer=True, minimize="size"),
)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)

@@ -207,9 +203,7 @@ def compute_pairwise_score(
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)
return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)

@@ -225,8 +219,6 @@ def compute_self_measurement_score(
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)
return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
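After this change, each call site fixes the contraction planner directly ("size" for the query-batched pairwise score, "flops" otherwise) rather than consulting the removed `einsum_minimize_size` flag. Below is a self-contained sketch of the cached `contract_expression` pattern these modules rely on; the einsum string matches the per-token pairwise-score contraction used in the linear module, and the dimension sizes are made up for illustration.

import torch
from opt_einsum import DynamicProgramming, contract_expression

# Build the contraction plan once from operand shapes, then reuse it for every batch.
expr = contract_expression(
    "qio,bti,bto->qbt",
    (8, 16, 32),   # (q, d1, d2): preconditioned query gradients
    (4, 10, 16),   # (b, t, d1): per-token output gradients
    (4, 10, 32),   # (b, t, d2): per-token input activations
    optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)

scores = expr(torch.randn(8, 16, 32), torch.randn(4, 10, 16), torch.randn(4, 10, 32))
print(scores.shape)  # torch.Size([8, 4, 10])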
16 changes: 6 additions & 10 deletions kronfluence/module/linear.py
@@ -62,7 +62,7 @@ def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Ten

def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor:
input_activation = self._flatten_input_activation(input_activation=input_activation)
summed_gradient = contract("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(0)
summed_gradient = contract("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(dim=0)
return summed_gradient

def compute_per_sample_gradient(
@@ -93,25 +93,23 @@ def compute_pairwise_score(
right_mat.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
optimize=DynamicProgramming(search_outer=True, minimize="size"),
)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)

if self.einsum_expression is None:
if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
expr = "qio,bti,bto->qbt"
minimize = "size"
else:
expr = "qio,b...i,b...o->qb"
minimize = "flops"
self.einsum_expression = contract_expression(
expr,
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
optimize=DynamicProgramming(search_outer=True, minimize=minimize),
)
return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)

@@ -125,8 +123,6 @@ def compute_self_measurement_score(
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
optimize=DynamicProgramming(search_outer=True, minimize="flops"),
)
return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)
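The per-token and pooled branches above compute consistent quantities: the pooled score from `"qio,b...i,b...o->qb"` equals the per-token score from `"qio,bti,bto->qbt"` summed over tokens. A quick numerical check of that relationship with plain `opt_einsum` (dimension sizes are arbitrary):

import torch
from opt_einsum import contract

q, b, t, d1, d2 = 3, 2, 5, 4, 6
precond = torch.randn(q, d1, d2)
out_grad = torch.randn(b, t, d1)
in_act = torch.randn(b, t, d2)

per_token = contract("qio,bti,bto->qbt", precond, out_grad, in_act)
pooled = contract("qio,b...i,b...o->qb", precond, out_grad, in_act)

# Summing the per-token scores over the token axis recovers the pooled scores.
assert torch.allclose(per_token.sum(dim=-1), pooled, atol=1e-5)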
25 changes: 18 additions & 7 deletions kronfluence/module/tracked_module.py
@@ -54,7 +54,14 @@ class TrackedModule(nn.Module):
SUPPORTED_MODULES: Dict[Type[nn.Module], Any] = {}

def __init_subclass__(cls, module_type: Type[nn.Module] = None, **kwargs: Any) -> None:
"""Automatically registers subclasses as supported modules."""
"""Automatically registers subclasses as supported modules.
Args:
module_type (Type[nn.Module], optional):
The type of module this subclass supports.
**kwargs:
Additional keyword arguments.
"""
super().__init_subclass__(**kwargs)
if module_type is not None:
cls.SUPPORTED_MODULES[module_type] = cls
@@ -75,17 +82,16 @@ def __init__(
original_module (nn.Module):
The original module to be wrapped.
factor_args (FactorArguments, optional):
Arguments for computing influence factors.
Arguments for computing factors.
score_args (ScoreArguments, optional):
Arguments for computing influence scores.
per_sample_gradient_process_fnc (Callable, optional):
Function to post-process per-sample gradients.
Optional function to post-process per-sample gradients.
"""
super().__init__()

self.name = name
self.original_module = original_module
# A way to avoid Autograd computing the gradient with respect to the model parameters.
self._constant: torch.Tensor = nn.Parameter(
torch.zeros(
1,
@@ -96,9 +102,7 @@ def __init__(
self.current_mode = ModuleMode.DEFAULT
self.factor_args = FactorArguments() if factor_args is None else factor_args
self.score_args = ScoreArguments() if score_args is None else score_args
self.state = State()
self.per_sample_gradient_process_fnc = per_sample_gradient_process_fnc
self.einsum_expression = None

self._trackers = {
ModuleMode.DEFAULT: BaseTracker(self),
@@ -114,6 +118,13 @@ def __init__(
self.attention_mask: Optional[torch.Tensor] = None
self.gradient_scale: float = 1.0
self.storage: Dict[str, Optional[Union[torch.Tensor, PRECONDITIONED_GRADIENT_TYPE]]] = {}
self.state: State = State()
self.einsum_expression: Optional[Callable] = None

self._initialize_storage()

def _initialize_storage(self) -> None:
"""Initializes trackers for different module modes."""

# Storage for activation and pseudo-gradient covariance matrices #
for covariance_factor_name in COVARIANCE_FACTOR_NAMES:
@@ -142,7 +153,7 @@ def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
return outputs + self._constant

def prepare_storage(self, device: torch.device) -> None:
"""Performs any necessary operations on storage before computing any metrics."""
"""Performs any necessary operations on storage before computing influence scores."""
FactorConfig.CONFIGS[self.factor_args.strategy].prepare(
storage=self.storage,
score_args=self.score_args,
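The `__init_subclass__` docstring added above documents the registry pattern behind `SUPPORTED_MODULES`: each `TrackedModule` subclass declares the `nn.Module` type it wraps and is registered automatically at class-creation time. A generic, self-contained sketch of that pattern follows; the class names here are illustrative, not kronfluence's.

from typing import Any, Dict, Optional, Type

from torch import nn


class TrackedBaseSketch(nn.Module):
    """Base class whose subclasses self-register by the module type they wrap."""

    SUPPORTED_MODULES: Dict[Type[nn.Module], Type["TrackedBaseSketch"]] = {}

    def __init_subclass__(cls, module_type: Optional[Type[nn.Module]] = None, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if module_type is not None:
            cls.SUPPORTED_MODULES[module_type] = cls


class TrackedLinearSketch(TrackedBaseSketch, module_type=nn.Linear):
    pass


class TrackedConv2dSketch(TrackedBaseSketch, module_type=nn.Conv2d):
    pass


# Wrapping code can look up the tracked wrapper for any supported layer type.
print(TrackedBaseSketch.SUPPORTED_MODULES[nn.Linear].__name__)  # TrackedLinearSketch
print(nn.Conv2d in TrackedBaseSketch.SUPPORTED_MODULES)         # True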