Skip to content

Commit

Permalink
Clean up factor arguments
Browse files · Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 1, 2024
1 parent f403704 commit 312f43e
Show file tree
Hide file tree
Showing 17 changed files with 724 additions and 653 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
pytest -vx tests/test_testable_tasks.py
pytest -vx tests/factors/test_covariances.py
pytest -vx tests/factors/test_eigens.py
pytest -vx tests/factors/test_lambdas.py
pytest -vx tests/modules/test_modules.py
pytest -vx tests/modules/test_per_sample_gradients.py
pytest -vx tests/modules/test_svd.py
Expand Down
80 changes: 38 additions & 42 deletions kronfluence/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,110 +32,106 @@ class FactorArguments(Arguments):
# General configuration. #
strategy: str = field(
default="ekfac",
metadata={"help": "Strategy for computing preconditioning factors."},
metadata={
"help": "Specifies the algorithm for computing influence factors. Default is 'ekfac' "
"(Eigenvalue-corrected Kronecker-factored Approximate Curvature)."
},
)
use_empirical_fisher: bool = field(
default=False,
metadata={
"help": "Whether to use empirical Fisher (using labels from batch) instead of "
"help": "Determines whether to approximate empirical Fisher (using true labels) or "
"true Fisher (using sampled labels)."
},
)
distributed_sync_steps: int = field(
distributed_sync_interval: int = field(
default=1_000,
metadata={
"help": "Specifies the total iteration step to synchronize the process when using distributed setting."
},
metadata={"help": "Number of iterations between synchronization steps in distributed computing settings."},
)
amp_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Dtype for automatic mixed precision (AMP). Disables AMP if None."},
metadata={"help": "Data type for automatic mixed precision (AMP). If `None`, AMP is disabled."},
)
shared_parameters_exist: bool = field(
has_shared_parameters: bool = field(
default=False,
metadata={"help": "Specifies whether the shared parameters exist in the forward pass."},
metadata={"help": "Indicates whether shared parameters are present in the model's forward pass."},
)

# Configuration for fitting covariance matrices. #
covariance_max_examples: Optional[int] = field(
default=100_000,
metadata={
"help": "Maximum number of examples for fitting covariance matrices. "
"Uses all data examples for the given dataset if None."
"help": "Maximum number of examples to use when fitting covariance matrices. "
"Uses entire dataset if `None`."
},
)
covariance_data_partition_size: int = field(
covariance_data_partitions: int = field(
default=1,
metadata={
"help": "Number of data partitions for computing covariance matrices. "
"For example, when `covariance_data_partition_size = 2`, the dataset is split "
"into 2 chunks and covariance matrices are separately computed for each chunk."
},
metadata={"help": "Number of partitions to divide the dataset into for covariance matrix computation."},
)
covariance_module_partition_size: int = field(
covariance_module_partitions: int = field(
default=1,
metadata={
"help": "Number of module partitions for computing covariance matrices. "
"For example, when `covariance_module_partition_size = 2`, the modules (layers) are split "
"into 2 chunks and covariance matrices are separately computed for each chunk."
"help": "Number of partitions to divide the model's modules (layers) into for "
"covariance matrix computation."
},
)
activation_covariance_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing activation covariance matrices."},
metadata={"help": "Data type for activation covariance computations."},
)
gradient_covariance_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing pseudo-gradient covariance matrices."},
metadata={"help": "Data type for pseudo-gradient covariance computations."},
)

# Configuration for performing eigendecomposition. #
eigendecomposition_dtype: torch.dtype = field(
default=torch.float64,
metadata={"help": "Dtype for performing Eigendecomposition. Recommended to use `torch.float64`."},
metadata={
"help": "Data type for eigendecomposition computations. Double precision (`torch.float64`) is "
"recommended for numerical stability."
},
)

# Configuration for fitting Lambda matrices. #
lambda_max_examples: Optional[int] = field(
default=100_000,
metadata={
"help": "Maximum number of examples for fitting Lambda matrices. "
"Uses all data examples for the given dataset if None."
"help": "Maximum number of examples to use when fitting Lambda matrices. Uses entire dataset if `None`."
},
)
lambda_data_partition_size: int = field(
lambda_data_partitions: int = field(
default=1,
metadata={
"help": "Number of data partitions for computing Lambda matrices. "
"For example, when `lambda_data_partition_size = 2`, the dataset is split "
"into 2 chunks and Lambda matrices are separately computed for each chunk."
},
metadata={"help": "Number of partitions to divide the dataset into for Lambda matrix computation."},
)
lambda_module_partition_size: int = field(
lambda_module_partitions: int = field(
default=1,
metadata={
"help": "Number of module partitions for computing Lambda matrices. "
"For example, when `lambda_module_partition_size = 2`, the modules (layers) are split "
"into 2 chunks and Lambda matrices are separately computed for each chunk."
"help": "Number of partitions to divide the model's modules (layers) into for Lambda matrix computation."
},
)
lambda_iterative_aggregate: bool = field(
use_iterative_lambda_aggregation: bool = field(
default=False,
metadata={
"help": "Whether to aggregate squared sum of projected per-sample-gradient with for-loop iterations."
"help": "If True, aggregates the squared sum of projected per-sample gradients "
"iteratively to reduce GPU memory usage."
},
)
cached_activation_cpu_offload: bool = field(
offload_activations_to_cpu: bool = field(
default=False,
metadata={"help": "Whether to offload cached activation to CPU for computing the per-sample-gradient."},
metadata={
"help": "If True, offloads cached activations to CPU memory when computing "
"per-sample gradients, reducing GPU memory usage."
},
)
per_sample_gradient_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing per-sample-gradients."},
metadata={"help": "Data type for per-sample-gradient computations."},
)
lambda_dtype: torch.dtype = field(
default=torch.float32,
metadata={"help": "Dtype for computing Lambda (corrected eigenvalues) matrices."},
metadata={"help": "Data type for Lambda matrix computations."},
)


Expand Down
28 changes: 14 additions & 14 deletions kronfluence/computer/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,34 +249,34 @@ def _configure_dataloader(self, dataloader_kwargs: DataLoaderKwargs) -> Dict[str
def _get_data_partition(
self,
total_data_examples: int,
data_partition_size: int,
data_partitions: int,
target_data_partitions: Optional[Union[int, List[int]]],
) -> Tuple[List[Tuple[int, int]], List[int]]:
"""Partitions the dataset into several chunks."""
if total_data_examples < data_partition_size:
if total_data_examples < data_partitions:
error_msg = (
f"Data partition size ({data_partition_size}) cannot be greater than the "
f"Data partition size ({data_partitions}) cannot be greater than the "
f"total data points ({total_data_examples}). Please reduce the data partition "
f"size in the argument."
)
self.logger.error(error_msg)
raise ValueError(error_msg)

indices_partitions = make_indices_partition(
total_data_examples=total_data_examples, partition_size=data_partition_size
total_data_examples=total_data_examples, partition_size=data_partitions
)

if target_data_partitions is None:
target_data_partitions = list(range(data_partition_size))
target_data_partitions = list(range(data_partitions))

if isinstance(target_data_partitions, int):
target_data_partitions = [target_data_partitions]

for data_partition in target_data_partitions:
if data_partition < 0 or data_partition > data_partition_size:
if data_partition < 0 or data_partition > data_partitions:
error_msg = (
f"Invalid data partition {data_partition} encountered. "
f"The module partition needs to be in between [0, {data_partition_size})."
f"The module partition needs to be in between [0, {data_partitions})."
)
self.logger.error(error_msg)
raise ValueError(error_msg)
Expand All @@ -285,15 +285,15 @@ def _get_data_partition(

def _get_module_partition(
self,
module_partition_size: int,
module_partitions: int,
target_module_partitions: Optional[Union[int, List[int]]],
) -> Tuple[List[List[str]], List[int]]:
"""Partitions the modules into several chunks."""
tracked_module_names = get_tracked_module_names(self.model)

if len(tracked_module_names) < module_partition_size:
if len(tracked_module_names) < module_partitions:
error_msg = (
f"Module partition size ({module_partition_size}) cannot be greater than the "
f"Module partition size ({module_partitions}) cannot be greater than the "
f"total tracked modules ({len(tracked_module_names)}). Please reduce the module partition "
f"size in the argument."
)
Expand All @@ -302,20 +302,20 @@ def _get_module_partition(

modules_partition_list = make_modules_partition(
total_module_names=tracked_module_names,
partition_size=module_partition_size,
partition_size=module_partitions,
)

if target_module_partitions is None:
target_module_partitions = list(range(module_partition_size))
target_module_partitions = list(range(module_partitions))

if isinstance(target_module_partitions, int):
target_module_partitions = [target_module_partitions]

for module_partition in target_module_partitions:
if module_partition < 0 or module_partition > module_partition_size:
if module_partition < 0 or module_partition > module_partitions:
error_msg = (
f"Invalid module partition {module_partition} encountered. "
f"The module partition needs to be in between [0, {module_partition_size})."
f"The module partition needs to be in between [0, {module_partitions})."
)
self.logger.error(error_msg)
raise ValueError(error_msg)
Expand Down
61 changes: 33 additions & 28 deletions kronfluence/computer/factor_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def _configure_and_save_factor_args(
def _aggregate_factors(
self,
factors_name: str,
data_partition_size: int,
module_partition_size: int,
exists_fnc: Callable,
data_partitions: int,
module_partitions: int,
exist_fnc: Callable,
load_fnc: Callable,
save_fnc: Callable,
) -> Optional[FACTOR_TYPE]:
Expand All @@ -73,18 +73,18 @@ def _aggregate_factors(
self.logger.error(error_msg)
raise FileNotFoundError(error_msg)

all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)]
all_required_partitions = [(i, j) for i in range(data_partitions) for j in range(module_partitions)]
all_partition_exists = all(
exists_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions
exist_fnc(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions
)
if not all_partition_exists:
self.logger.warning("Factors are not aggregated as factors for some partitions are not yet computed.")
return

start_time = time.time()
aggregated_factors: FACTOR_TYPE = {}
for data_partition in range(data_partition_size):
for module_partition in range(module_partition_size):
for data_partition in range(data_partitions):
for module_partition in range(module_partitions):
loaded_factors = load_fnc(
output_dir=factors_output_dir,
partition=(data_partition, module_partition),
Expand All @@ -95,9 +95,10 @@ def _aggregate_factors(

for module_name in factors:
if module_name not in aggregated_factors[factor_name]:
aggregated_factors[factor_name][module_name] = factors[module_name]
else:
aggregated_factors[factor_name][module_name].add_(factors[module_name])
aggregated_factors[factor_name][module_name] = torch.zeros(
size=factors[module_name].shape, dtype=factors[module_name].dtype, requires_grad=False
)
aggregated_factors[factor_name][module_name].add_(factors[module_name])
del loaded_factors
save_fnc(
output_dir=factors_output_dir,
Expand Down Expand Up @@ -227,9 +228,7 @@ def fit_covariance_matrices(
total_data_examples = min([factor_args.covariance_max_examples, len(dataset)])
self.logger.info(f"Total data examples to fit covariance matrices: {total_data_examples}.")

no_partition = (
factor_args.covariance_data_partition_size == 1 and factor_args.covariance_module_partition_size == 1
)
no_partition = factor_args.covariance_data_partitions == 1 and factor_args.covariance_module_partitions == 1
partition_provided = target_data_partitions is not None or target_module_partitions is not None
if no_partition and partition_provided:
error_msg = (
Expand All @@ -241,17 +240,20 @@ def fit_covariance_matrices(

data_partition_indices, target_data_partitions = self._get_data_partition(
total_data_examples=total_data_examples,
data_partition_size=factor_args.covariance_data_partition_size,
data_partitions=factor_args.covariance_data_partitions,
target_data_partitions=target_data_partitions,
)
max_partition_examples = total_data_examples // factor_args.covariance_data_partition_size
max_partition_examples = total_data_examples // factor_args.covariance_data_partitions
module_partition_names, target_module_partitions = self._get_module_partition(
module_partition_size=factor_args.covariance_module_partition_size,
module_partitions=factor_args.covariance_module_partitions,
target_module_partitions=target_module_partitions,
)

if max_partition_examples < self.state.num_processes:
error_msg = "The number of processes are more than the data examples. Try reducing the number of processes."
error_msg = (
"The number of processes are larger than the total data examples. "
"Try reducing the number of processes."
)
self.logger.error(error_msg)
raise ValueError(error_msg)

Expand Down Expand Up @@ -372,9 +374,9 @@ def aggregate_covariance_matrices(
with self.profiler.profile("Aggregate Covariance"):
self._aggregate_factors(
factors_name=factors_name,
data_partition_size=factor_args.covariance_data_partition_size,
module_partition_size=factor_args.covariance_module_partition_size,
exists_fnc=covariance_matrices_exist,
data_partitions=factor_args.covariance_data_partitions,
module_partitions=factor_args.covariance_module_partitions,
exist_fnc=covariance_matrices_exist,
load_fnc=load_covariance_matrices,
save_fnc=save_covariance_matrices,
)
Expand Down Expand Up @@ -580,7 +582,7 @@ def fit_lambda_matrices(
total_data_examples = min([factor_args.lambda_max_examples, len(dataset)])
self.logger.info(f"Total data examples to fit Lambda matrices: {total_data_examples}.")

no_partition = factor_args.lambda_data_partition_size == 1 and factor_args.lambda_module_partition_size == 1
no_partition = factor_args.lambda_data_partitions == 1 and factor_args.lambda_module_partitions == 1
partition_provided = target_data_partitions is not None or target_module_partitions is not None
if no_partition and partition_provided:
error_msg = (
Expand All @@ -592,17 +594,20 @@ def fit_lambda_matrices(

data_partition_indices, target_data_partitions = self._get_data_partition(
total_data_examples=total_data_examples,
data_partition_size=factor_args.lambda_data_partition_size,
data_partitions=factor_args.lambda_data_partitions,
target_data_partitions=target_data_partitions,
)
max_partition_examples = total_data_examples // factor_args.lambda_data_partition_size
max_partition_examples = total_data_examples // factor_args.lambda_data_partitions
module_partition_names, target_module_partitions = self._get_module_partition(
module_partition_size=factor_args.lambda_module_partition_size,
module_partitions=factor_args.lambda_module_partitions,
target_module_partitions=target_module_partitions,
)

if max_partition_examples < self.state.num_processes:
error_msg = "The number of processes are more than the data examples. Try reducing the number of processes."
error_msg = (
"The number of processes are larger than the total data examples. "
"Try reducing the number of processes."
)
self.logger.error(error_msg)
raise ValueError(error_msg)

Expand Down Expand Up @@ -725,9 +730,9 @@ def aggregate_lambda_matrices(
with self.profiler.profile("Aggregate Lambda"):
self._aggregate_factors(
factors_name=factors_name,
data_partition_size=factor_args.lambda_data_partition_size,
module_partition_size=factor_args.lambda_module_partition_size,
exists_fnc=lambda_matrices_exist,
data_partitions=factor_args.lambda_data_partitions,
module_partitions=factor_args.lambda_module_partitions,
exist_fnc=lambda_matrices_exist,
load_fnc=load_lambda_matrices,
save_fnc=save_lambda_matrices,
)
Loading

0 comments on commit 312f43e

Please sign in to comment.