Skip to content

Commit

Permalink
Add score computations
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 6, 2024
1 parent a80595e commit 42e539b
Show file tree
Hide file tree
Showing 11 changed files with 381 additions and 338 deletions.
6 changes: 3 additions & 3 deletions kronfluence/module/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn.functional as F
from einconv.utils import get_conv_paddings
from einops import rearrange, reduce
from opt_einsum import DynamicProgramming, contract_expression
from opt_einsum import DynamicProgramming, contract_expression, contract
from torch import nn
from torch.nn.modules.utils import _pair

Expand Down Expand Up @@ -160,7 +160,7 @@ def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradien
input_activation = self._flatten_input_activation(input_activation=input_activation)
input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
summed_gradient = torch.einsum("bci,bco->io", output_gradient, input_activation)
summed_gradient = contract("bci,bco->io", output_gradient, input_activation)
return summed_gradient.view((1, *summed_gradient.size()))

def compute_per_sample_gradient(
Expand All @@ -171,7 +171,7 @@ def compute_per_sample_gradient(
input_activation = self._flatten_input_activation(input_activation=input_activation)
input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1))
output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o")
per_sample_gradient = torch.einsum("bci,bco->bio", output_gradient, input_activation)
per_sample_gradient = contract("bci,bco->bio", output_gradient, input_activation)
if self.per_sample_gradient_process_fnc is not None:
per_sample_gradient = self.per_sample_gradient_process_fnc(
module_name=self.name, gradient=per_sample_gradient
Expand Down
64 changes: 24 additions & 40 deletions kronfluence/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from einops import rearrange
from opt_einsum import DynamicProgramming, contract_expression
from opt_einsum import DynamicProgramming, contract, contract_expression
from torch import nn

from kronfluence.module.tracked_module import TrackedModule
Expand Down Expand Up @@ -62,14 +62,14 @@ def _flatten_input_activation(self, input_activation: torch.Tensor) -> torch.Ten

def compute_summed_gradient(self, input_activation: torch.Tensor, output_gradient: torch.Tensor) -> torch.Tensor:
input_activation = self._flatten_input_activation(input_activation=input_activation)
summed_gradient = torch.einsum("b...i,b...o->io", output_gradient, input_activation)
return summed_gradient.view((1, *summed_gradient.size()))
summed_gradient = contract("b...i,b...o->io", output_gradient, input_activation).unsqueeze_(0)
return summed_gradient

def compute_per_sample_gradient(
self, input_activation: torch.Tensor, output_gradient: torch.Tensor
) -> torch.Tensor:
input_activation = self._flatten_input_activation(input_activation=input_activation)
per_sample_gradient = torch.einsum("b...i,b...o->bio", output_gradient, input_activation)
per_sample_gradient = contract("b...i,b...o->bio", output_gradient, input_activation)
if self.per_sample_gradient_process_fnc is not None:
per_sample_gradient = self.per_sample_gradient_process_fnc(
module_name=self.name, gradient=per_sample_gradient
Expand All @@ -82,53 +82,37 @@ def compute_pairwise_score(
input_activation = self._flatten_input_activation(input_activation=input_activation)
if isinstance(preconditioned_gradient, list):
left_mat, right_mat = preconditioned_gradient

if self.einsum_expression is None:
if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
self.einsum_expression = contract_expression(
"qik,qko,bti,bto->qbt",
left_mat.shape,
right_mat.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
)
expr = "qik,qko,bti,bto->qbt"
else:
self.einsum_expression = contract_expression(
"qik,qko,b...i,b...o->qb",
left_mat.shape,
right_mat.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)

if self.einsum_expression is None:
if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
expr = "qik,qko,b...i,b...o->qb"
self.einsum_expression = contract_expression(
"qio,bti,bto->qbt",
preconditioned_gradient.shape,
expr,
left_mat.shape,
right_mat.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
)
return self.einsum_expression(left_mat, right_mat, output_gradient, input_activation)

if self.einsum_expression is None:
if self.score_args.compute_per_token_scores and len(input_activation.shape) == 3:
expr = "qio,bti,bto->qbt"
else:
self.einsum_expression = contract_expression(
"qio,b...i,b...o->qb",
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
)
expr = "qio,b...i,b...o->qb"
self.einsum_expression = contract_expression(
expr,
preconditioned_gradient.shape,
output_gradient.shape,
input_activation.shape,
optimize=DynamicProgramming(
search_outer=True, minimize="size" if self.score_args.einsum_minimize_size else "flops"
),
)
return self.einsum_expression(preconditioned_gradient, output_gradient, input_activation)

def compute_self_measurement_score(
Expand Down
6 changes: 2 additions & 4 deletions kronfluence/module/tracker/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient.mul_(self.module.gradient_scale)
else:
output_gradient = output_gradient * self.module.gradient_scale
self.cached_activations = self.cached_activations.to(device=output_gradient.device)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations,
input_activation=self.cached_activations.to(device=output_gradient.device),
output_gradient=output_gradient,
).to(dtype=self.module.factor_args.lambda_dtype)
self.clear_all_cache()
Expand All @@ -281,9 +280,8 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None:
else:
output_gradient = output_gradient * self.module.gradient_scale
cached_activation = self.cached_activations.pop()
cached_activation = cached_activation.to(device=output_gradient.device)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=cached_activation,
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
)
if self.cached_per_sample_gradient is None:
Expand Down
117 changes: 46 additions & 71 deletions kronfluence/module/tracker/gradient.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,65 @@
from typing import List, Tuple
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn as nn

from kronfluence.factor.config import FactorConfig
from kronfluence.module.tracker.base import BaseTracker
from kronfluence.utils.constants import (
ACCUMULATED_PRECONDITIONED_GRADIENT_NAME,
AGGREGATED_GRADIENT_NAME,
PRECONDITIONED_GRADIENT_NAME,
)
from kronfluence.utils.constants import AGGREGATED_GRADIENT_NAME


class GradientTracker(BaseTracker):
"""Tracks and computes summed gradient for a given module."""
"""Tracks and computes aggregated gradient for a given module."""

def register_hooks(self) -> None:
"""Sets up hooks to compute and keep track of summed gradient."""
"""Sets up hooks to compute and keep track of aggregated gradient."""

@torch.no_grad()
def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
del module
cached_activation = inputs[0].detach()
device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device
cached_activation = cached_activation.to(
device=device,
dtype=self.module.score_args.per_sample_gradient_dtype,
copy=True,
)

if self.module.factor_args.has_shared_parameters:
if self.cached_activations is None:
self.cached_activations = []
self.cached_activations.append(cached_activation)
else:
self.cached_activations = cached_activation
with torch.no_grad():
cached_activation = inputs[0].detach()
device = "cpu" if self.module.score_args.offload_activations_to_cpu else cached_activation.device
cached_activation = cached_activation.to(
device=device,
dtype=self.module.score_args.per_sample_gradient_dtype,
copy=True,
)
if self.module.factor_args.has_shared_parameters:
if self.cached_activations is None:
self.cached_activations = []
self.cached_activations.append(cached_activation)
else:
self.cached_activations = cached_activation

outputs.register_hook(
shared_backward_hook if self.module.factor_args.has_shared_parameters else backward_hook
)
self.cached_hooks.append(outputs.register_hook(backward_hook))

@torch.no_grad()
def backward_hook(output_gradient: torch.Tensor) -> None:
if self.cached_activations is None:
self._raise_cache_not_found_exception()

output_gradient = self._scale_output_gradient(
output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype
)
handle = self.cached_hooks.pop()
handle.remove()
original_dtype = output_gradient.dtype
target_dtype = self.module.score_args.per_sample_gradient_dtype
output_gradient = output_gradient.detach().to(dtype=target_dtype)
if self.module.gradient_scale != 1.0:
if original_dtype != target_dtype:
output_gradient.mul_(self.module.gradient_scale)
else:
output_gradient = output_gradient * self.module.gradient_scale

if isinstance(self.cached_activations, list):
cached_activation = self.cached_activations.pop()
else:
cached_activation = self.cached_activations
if self.module.per_sample_gradient_process_fnc is None:
summed_gradient = self.module.compute_summed_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
)
else:
summed_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
).sum(dim=0, keepdim=True)
self.clear_all_cache()
Expand All @@ -65,49 +68,16 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(summed_gradient, requires_grad=False)
self.module.storage[AGGREGATED_GRADIENT_NAME].add_(summed_gradient)

@torch.no_grad()
def shared_backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient = self._scale_output_gradient(
output_gradient=output_gradient, target_dtype=self.module.score_args.per_sample_gradient_dtype
)
cached_activation = self.cached_activations.pop()
if self.module.per_sample_gradient_process_fnc is None:
summed_gradient = self.module.compute_summed_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
)
else:
summed_gradient = self.module.comute_per_sample_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
).sum(dim=0, keepdim=True)

if self.cached_per_sample_gradient is None:
self.cached_per_sample_gradient = torch.zeros_like(summed_gradient, requires_grad=False)
self.cached_per_sample_gradient.add_(summed_gradient)

self.registered_hooks.append(self.module.original_module.register_forward_hook(forward_hook))

def exist(self) -> bool:
return self.module.storage[AGGREGATED_GRADIENT_NAME] is not None
self.registered_hooks.append(self.module.register_forward_hook(forward_hook))

@torch.no_grad()
def finalize_iteration(self):
"""Computes preconditioned gradient using cached per-sample gradients."""
if not self.module.factor_args.has_shared_parameters:
return
if self.module.storage[AGGREGATED_GRADIENT_NAME] is None:
self.module.storage[AGGREGATED_GRADIENT_NAME] = torch.zeros_like(
self.cached_per_sample_gradient, requires_grad=False
)
self.module.storage[AGGREGATED_GRADIENT_NAME].add_(self.cached_per_sample_gradient)
"""Clears all cached activations from memory."""
self.clear_all_cache()

def release_memory(self) -> None:
"""Clears summed gradients from memory."""
del self.module.storage[AGGREGATED_GRADIENT_NAME]
self.module.storage[AGGREGATED_GRADIENT_NAME] = None
self.clear_all_cache()
def exist(self) -> bool:
"""Checks if aggregated gradient is available."""
return self.module.storage[AGGREGATED_GRADIENT_NAME] is not None

def synchronize(self, num_processes: int = 1) -> None:
"""Aggregates summed gradient across multiple devices or nodes in a distributed setting."""
Expand All @@ -124,3 +94,8 @@ def synchronize(self, num_processes: int = 1) -> None:
tensor=self.module.storage[AGGREGATED_GRADIENT_NAME],
op=dist.ReduceOp.SUM,
)

def release_memory(self) -> None:
"""Clears aggregated gradients from memory."""
self.clear_all_cache()
self.module.storage[AGGREGATED_GRADIENT_NAME] = None
Loading

0 comments on commit 42e539b

Please sign in to comment.