Commit 288427e

pomonam committed Jul 5, 2024
1 parent c135ecd commit 288427e
Showing 6 changed files with 54 additions and 32 deletions.
10 changes: 8 additions & 2 deletions kronfluence/computer/factor_computer.py
@@ -297,6 +297,7 @@ def fit_covariance_matrices(
total_data_examples=max_partition_examples,
)

self._reset_memory()
start_time = get_time(state=self.state)
with self.profiler.profile("Fit Covariance"):
loader = self._get_dataloader(
@@ -331,8 +332,9 @@ def fit_covariance_matrices(
metadata=factor_args.to_str_dict(),
)
self.state.wait_for_everyone()
del covariance_factors, loader
self.logger.info(f"Saved covariance matrices at `{factors_output_dir}`.")
del num_data_processed, covariance_factors, loader
self._reset_memory()

all_end_time = get_time(state=self.state)
elapsed_time = all_end_time - all_start_time
@@ -442,6 +444,7 @@ def perform_eigendecomposition(
)
self.state.wait_for_everyone()

self._reset_memory()
eigen_factors = None
if self.state.is_main_process:
start_time = time.time()
@@ -462,6 +465,7 @@
output_dir=factors_output_dir, factors=eigen_factors, metadata=factor_args.to_str_dict()
)
self.logger.info(f"Saved eigendecomposition results at `{factors_output_dir}`.")
self._reset_memory()
self.state.wait_for_everyone()
self._log_profile_summary(name=f"factors_{factors_name}_eigendecomposition")

@@ -645,6 +649,7 @@ def fit_lambda_matrices(
total_data_examples=max_partition_examples,
)

self._reset_memory()
start_time = get_time(state=self.state)
with self.profiler.profile("Fit Lambda"):
loader = self._get_dataloader(
@@ -680,8 +685,9 @@ def fit_lambda_matrices(
metadata=factor_args.to_str_dict(),
)
self.state.wait_for_everyone()
del lambda_factors, loader
self.logger.info(f"Saved Lambda matrices at `{factors_output_dir}`.")
del num_data_processed, lambda_factors, loader
self._reset_memory()

all_end_time = get_time(state=self.state)
elapsed_time = all_end_time - all_start_time
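The factor_computer.py changes wrap each fitting stage with `self._reset_memory()` and delete the large local references (`num_data_processed`, `covariance_factors`, `loader`) before resetting. The diff does not show `_reset_memory` itself; below is a minimal, hypothetical sketch of what such a helper typically does, assuming it combines Python garbage collection with emptying the CUDA cache.

```python
# Hypothetical sketch of a memory-reset helper; the actual `_reset_memory`
# implementation in kronfluence is not shown in this diff.
import gc

import torch


def reset_memory() -> None:
    """Releases Python-side and CUDA-side memory between computation stages."""
    gc.collect()  # Collect objects whose last references were just `del`eted.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Return cached, unused CUDA blocks to the allocator.
```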
3 changes: 1 addition & 2 deletions kronfluence/factor/covariance.py
@@ -192,7 +192,6 @@ def fit_covariance_matrices_with_loader(
mode=ModuleMode.COVARIANCE,
release_memory=True,
)
release_memory()

total_steps = 0
num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
@@ -233,6 +232,7 @@ def fit_covariance_matrices_with_loader(
state.wait_for_everyone()

num_data_processed.add_(find_batch_size(data=batch))
del batch, attention_mask, loss
total_steps += 1
pbar.update(1)

@@ -260,7 +260,6 @@ def fit_covariance_matrices_with_loader(
if enable_amp:
set_gradient_scale(model=model, gradient_scale=1.0)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
release_memory()
state.wait_for_everyone()

return num_data_processed, saved_factors
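The covariance.py hunks drop the module-level `release_memory()` calls and instead add `del batch, attention_mask, loss` inside the data loop, so per-batch tensors are released before the next batch is loaded. A minimal sketch of that per-batch cleanup pattern follows; `model`, `loader`, and `compute_loss` are illustrative placeholders, not kronfluence APIs.

```python
import torch


def fit_statistics(model: torch.nn.Module, loader, compute_loss) -> None:
    """Accumulates hook-based statistics over a dataloader with per-batch cleanup."""
    for batch in loader:
        loss = compute_loss(model, batch)
        loss.backward()  # Registered hooks accumulate covariance statistics here.
        model.zero_grad(set_to_none=True)
        # Drop references to the batch and loss so their tensors can be freed
        # before the next iteration instead of lingering for one extra step.
        del batch, loss
```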
3 changes: 1 addition & 2 deletions kronfluence/factor/eigen.py
@@ -392,7 +392,6 @@ def fit_lambda_matrices_with_loader(
if eigen_factors is not None:
for name in eigen_factors:
set_factors(model=model, factor_name=name, factors=eigen_factors[name])
release_memory()

total_steps = 0
num_data_processed = torch.zeros((1,), dtype=torch.int64, requires_grad=False)
@@ -420,6 +419,7 @@
sample=not factor_args.use_empirical_fisher,
)
scaler.scale(loss).backward()
del loss

if factor_args.has_shared_parameters:
finalize_iteration(model=model, tracked_module_names=tracked_module_names)
@@ -459,6 +459,5 @@ def fit_lambda_matrices_with_loader(
set_gradient_scale(model=model, gradient_scale=1.0)
set_mode(model=model, mode=ModuleMode.DEFAULT, release_memory=True)
state.wait_for_everyone()
release_memory()

return num_data_processed, saved_factors
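In eigen.py the loss is deleted immediately after `scaler.scale(loss).backward()`, dropping the last external reference to the autograd graph as soon as gradients exist. A short sketch of that step, again with placeholder names rather than kronfluence APIs:

```python
import torch


def backward_step(model: torch.nn.Module, batch, compute_loss,
                  scaler: torch.cuda.amp.GradScaler) -> None:
    loss = compute_loss(model, batch)
    scaler.scale(loss).backward()  # Hooks capture per-sample statistics during this pass.
    # Once backward has run, the loss (and the graph it keeps alive) is no
    # longer needed, so the reference is dropped eagerly.
    del loss
```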
7 changes: 5 additions & 2 deletions kronfluence/module/tracked_module.py
@@ -27,6 +27,7 @@
PRECONDITIONED_GRADIENT_TYPE,
SELF_SCORE_VECTOR_NAME,
)
from kronfluence.utils.state import State


class ModuleMode(str, BaseEnum):
@@ -88,13 +89,14 @@ def __init__(
self._constant: torch.Tensor = nn.Parameter(
torch.zeros(
1,
dtype=self.original_module.weight.dtype,
requires_grad=True,
dtype=torch.float16,
)
)
self.current_mode = ModuleMode.DEFAULT
self.factor_args = FactorArguments() if factor_args is None else factor_args
self.score_args = ScoreArguments() if score_args is None else score_args
self.state = State()
self.per_sample_gradient_process_fnc = per_sample_gradient_process_fnc
self.einsum_expression = None

@@ -134,7 +136,8 @@ def __init__(

def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
"""A forward pass of the tracked module. This should have identical behavior to that of the original module."""
return self.original_module(inputs + self._constant, *args, **kwargs)
# return self.original_module(inputs + self._constant, *args, **kwargs)
return self.original_module(inputs, *args, **kwargs) + self._constant

def prepare_storage(self, device: torch.device) -> None:
"""Performs any necessary operations on storage before computing any metrics."""
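tracked_module.py changes the dummy `_constant` parameter to `torch.float16` and adds it to the module's output instead of its input, so the activations seen by forward hooks are exactly the original inputs while the wrapper still owns a parameter that participates in the graph. A sketch of this wrapper pattern with illustrative class and argument names; the float16 zero is promoted to the output dtype, so the result is numerically unchanged.

```python
import torch
from torch import nn


class Wrapped(nn.Module):
    """Illustrative stand-in for a tracked-module wrapper; not the kronfluence class."""

    def __init__(self, original_module: nn.Module) -> None:
        super().__init__()
        self.original_module = original_module
        # Zero-valued dummy parameter; adding it leaves the output unchanged.
        self._constant = nn.Parameter(torch.zeros(1, requires_grad=True, dtype=torch.float16))

    def forward(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Adding the constant to the output (rather than the input) means the
        # input activations captured by forward hooks are exactly the originals.
        return self.original_module(inputs, *args, **kwargs) + self._constant


wrapped = Wrapped(nn.Linear(4, 3))
out = wrapped(torch.randn(2, 4))  # Same values as the unwrapped module's output.
```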
6 changes: 6 additions & 0 deletions kronfluence/module/tracker/base.py
@@ -17,6 +17,7 @@ def __init__(self, module: nn.Module) -> None:
"""
self.module = module
self.registered_hooks: List[RemovableHandle] = []
self.cached_hooks: List[RemovableHandle] = []
self.cached_activations: Optional[Union[List[torch.Tensor], torch.Tensor]] = None
self.cached_per_sample_gradient: Optional[torch.Tensor] = None

@@ -32,6 +33,11 @@ def clear_all_cache(self) -> None:
del self.cached_activations, self.cached_per_sample_gradient
self.cached_activations, self.cached_per_sample_gradient = None, None

while self.cached_hooks:
handle = self.cached_hooks.pop()
handle.remove()
self.cached_hooks = []

def _scale_output_gradient(self, output_gradient: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
"""Scales the output gradient and convert to the target dtype.
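base.py introduces a `cached_hooks` list of `RemovableHandle`s so that hooks registered on output tensors can be detached in `clear_all_cache` instead of accumulating across iterations. A minimal sketch of that bookkeeping, with illustrative names rather than the library's tracker class:

```python
from typing import Callable, List

import torch
from torch.utils.hooks import RemovableHandle


class HookCache:
    def __init__(self) -> None:
        self.cached_hooks: List[RemovableHandle] = []

    def attach(self, tensor: torch.Tensor, fn: Callable) -> None:
        # Tensor.register_hook returns a RemovableHandle; keep it so the hook
        # can be detached later instead of piling up across forward passes.
        self.cached_hooks.append(tensor.register_hook(fn))

    def clear(self) -> None:
        while self.cached_hooks:
            self.cached_hooks.pop().remove()


cache = HookCache()
x = torch.randn(3, requires_grad=True)
y = x * 2
cache.attach(y, lambda grad: None)  # Hook observes y's gradient during backward.
y.sum().backward()
cache.clear()  # Detach any handles that are still registered.
```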
57 changes: 33 additions & 24 deletions kronfluence/module/tracker/factor.py
@@ -82,16 +82,20 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N
def register_hooks(self) -> None:
"""Sets up hooks to compute activation and gradient covariance matrices."""

@torch.no_grad()
def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
del module
# Computes and updates activation covariance during forward pass.
input_activation = inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype)
self._update_activation_covariance_matrix(input_activation=input_activation)
outputs.register_hook(backward_hook)
with torch.no_grad():
# Computes and updates activation covariance during forward pass.
input_activation = (
inputs[0].detach().to(dtype=self.module.factor_args.activation_covariance_dtype, copy=True)
)
self._update_activation_covariance_matrix(input_activation=input_activation)
self.cached_hooks.append(outputs.register_hook(backward_hook))

@torch.no_grad()
def backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
# Computes and updates pseudo-gradient covariance during backward pass.
original_dtype = output_gradient.dtype
target_dtype = self.module.factor_args.gradient_covariance_dtype
@@ -103,7 +107,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient = output_gradient * self.module.gradient_scale
self._update_gradient_covariance_matrix(output_gradient=output_gradient)

self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
# self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))
self.registered_hooks.append(self.module.register_forward_hook(forward_hook))

def exist(self) -> bool:
"""Checks if both activation and gradient covariance matrices are available."""
@@ -127,7 +132,6 @@ def synchronize(self, num_processes: int) -> None:
def release_memory(self) -> None:
"""Clears all covariance matrices from memory."""
for covariance_factor_name in COVARIANCE_FACTOR_NAMES:
del self.module.storage[covariance_factor_name]
self.module.storage[covariance_factor_name] = None


@@ -214,33 +218,36 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None:
def register_hooks(self) -> None:
"""Sets up hooks to compute lambda matrices."""

@torch.no_grad()
def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
del module
cached_activation = inputs[0].detach()
device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device
cached_activation = cached_activation.to(
device=device,
dtype=self.module.factor_args.per_sample_gradient_dtype,
copy=True,
)

if self.module.factor_args.has_shared_parameters:
if self.cached_activations is None:
self.cached_activations = []
self.cached_activations.append(cached_activation)
else:
self.cached_activations = cached_activation
with torch.no_grad():
cached_activation = inputs[0].detach()
device = "cpu" if self.module.factor_args.offload_activations_to_cpu else cached_activation.device
cached_activation = cached_activation.to(
device=device,
dtype=self.module.factor_args.per_sample_gradient_dtype,
copy=True,
)
if self.module.factor_args.has_shared_parameters:
if self.cached_activations is None:
self.cached_activations = []
self.cached_activations.append(cached_activation)
else:
self.cached_activations = cached_activation

outputs.register_hook(
shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook
self.cached_hooks.append(
outputs.register_hook(
shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook
)
)

@torch.no_grad()
def backward_hook(output_gradient: torch.Tensor) -> None:
if self.cached_activations is None:
self._raise_cache_not_found_exception()

handle = self.cached_hooks.pop()
handle.remove()
original_dtype = output_gradient.dtype
target_dtype = self.module.factor_args.per_sample_gradient_dtype
output_gradient = output_gradient.detach().to(dtype=target_dtype)
@@ -258,6 +265,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:

@torch.no_grad()
def shared_backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
original_dtype = output_gradient.dtype
target_dtype = self.module.factor_args.per_sample_gradient_dtype
output_gradient = output_gradient.detach().to(dtype=target_dtype)
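factor.py moves the forward hook onto the wrapper module, forces a copy of the cached activation with `copy=True`, and has each backward hook pop and remove its own handle from `cached_hooks`. Below is a self-contained sketch of that self-removing backward-hook pattern; the statistic updates are placeholders and the names are not the kronfluence API.

```python
from typing import List, Tuple

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle

cached_hooks: List[RemovableHandle] = []


def backward_hook(output_gradient: torch.Tensor) -> None:
    # The tensor hook only needs to fire once per backward pass; remove it now
    # so repeated forward/backward iterations do not accumulate stale hooks.
    cached_hooks.pop().remove()
    _ = output_gradient.detach().sum()  # Placeholder for a gradient-statistic update.


def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor, ...], outputs: torch.Tensor) -> None:
    with torch.no_grad():
        # `copy=True` guarantees the cached activation owns its own storage,
        # even when no dtype or device conversion is actually needed.
        activation = inputs[0].detach().to(dtype=torch.float32, copy=True)
        _ = activation.sum()  # Placeholder for an activation-covariance update.
    cached_hooks.append(outputs.register_hook(backward_hook))


layer = nn.Linear(4, 3)
handle = layer.register_forward_hook(forward_hook)
for _ in range(3):
    layer(torch.randn(2, 4)).sum().backward()  # backward_hook removes its own handle each pass.
handle.remove()
assert not cached_hooks  # All per-tensor hooks were cleaned up.
```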
