From 2df875b29064ac9eaeced7d67b661e65fa5a821f Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Sun, 7 Jul 2024 15:56:52 -0400 Subject: [PATCH] Modify fsdp --- tests/gpu_tests/fsdp_test.py | 153 +++++++++++++++++------------------ 1 file changed, 72 insertions(+), 81 deletions(-) diff --git a/tests/gpu_tests/fsdp_test.py b/tests/gpu_tests/fsdp_test.py index 8c08975..a8e1773 100644 --- a/tests/gpu_tests/fsdp_test.py +++ b/tests/gpu_tests/fsdp_test.py @@ -10,6 +10,7 @@ from kronfluence.analyzer import Analyzer, prepare_model from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import pytest_factor_arguments from kronfluence.utils.constants import ( ALL_MODULE_NAME, COVARIANCE_FACTOR_NAMES, @@ -18,7 +19,7 @@ from kronfluence.utils.model import apply_fsdp from tests.gpu_tests.pipeline import GpuTestTask, construct_test_mlp, get_mnist_dataset from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES -from tests.utils import check_tensor_dict_equivalence +from tests.utils import check_tensor_dict_equivalence, ATOL, RTOL LOCAL_RANK = int(os.environ["LOCAL_RANK"]) WORLD_RANK = int(os.environ["RANK"]) @@ -62,12 +63,7 @@ def setUpClass(cls) -> None: def test_covariance_matrices(self) -> None: covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_covariance_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -87,18 +83,13 @@ def test_covariance_matrices(self) -> None: assert check_tensor_dict_equivalence( covariance_factors[name], new_covariance_factors[name], - atol=1e-5, - rtol=1e-3, + atol=ATOL, + rtol=RTOL, ) def test_lambda_matrices(self): lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = pytest_factor_arguments() self.analyzer.fit_lambda_matrices( factors_name=NEW_FACTOR_NAME, dataset=self.train_dataset, @@ -119,74 +110,74 @@ def test_lambda_matrices(self): assert check_tensor_dict_equivalence( lambda_factors[name], new_lambda_factors[name], - atol=1e-3, - rtol=1e-1, + atol=ATOL, + rtol=RTOL, ) - def test_pairwise_scores(self) -> None: - pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) - - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) - self.analyzer.compute_pairwise_scores( - scores_name=NEW_SCORE_NAME, - factors_name=OLD_FACTOR_NAME, - query_dataset=self.eval_dataset, - train_dataset=self.train_dataset, - train_indices=list(range(TRAIN_INDICES)), - query_indices=list(range(QUERY_INDICES)), - per_device_query_batch_size=12, - per_device_train_batch_size=512, - score_args=score_args, - overwrite_output_dir=True, - ) - new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) - - if LOCAL_RANK == 0: - print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") - print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") - print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") - assert check_tensor_dict_equivalence( - pairwise_scores, - new_pairwise_scores, - atol=1e-5, - rtol=1e-3, - ) - - def test_self_scores(self) -> None: - self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) - - score_args = ScoreArguments( - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) - self.analyzer.compute_self_scores( - scores_name=NEW_SCORE_NAME, - factors_name=OLD_FACTOR_NAME, - train_dataset=self.train_dataset, - train_indices=list(range(TRAIN_INDICES)), - per_device_train_batch_size=512, - score_args=score_args, - overwrite_output_dir=True, - ) - new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME) - - if LOCAL_RANK == 0: - print(f"Previous score: {self_scores[ALL_MODULE_NAME]}") - print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}") - print(f"New score: {new_self_scores[ALL_MODULE_NAME]}") - print(f"New shape: {new_self_scores[ALL_MODULE_NAME].shape}") - assert check_tensor_dict_equivalence( - self_scores, - new_self_scores, - atol=1e-5, - rtol=1e-3, - ) + # def test_pairwise_scores(self) -> None: + # pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME) + # + # score_args = ScoreArguments( + # score_dtype=torch.float64, + # per_sample_gradient_dtype=torch.float64, + # precondition_dtype=torch.float64, + # ) + # self.analyzer.compute_pairwise_scores( + # scores_name=NEW_SCORE_NAME, + # factors_name=OLD_FACTOR_NAME, + # query_dataset=self.eval_dataset, + # train_dataset=self.train_dataset, + # train_indices=list(range(TRAIN_INDICES)), + # query_indices=list(range(QUERY_INDICES)), + # per_device_query_batch_size=12, + # per_device_train_batch_size=512, + # score_args=score_args, + # overwrite_output_dir=True, + # ) + # new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) + # + # if LOCAL_RANK == 0: + # print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") + # print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") + # print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") + # print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") + # assert check_tensor_dict_equivalence( + # pairwise_scores, + # new_pairwise_scores, + # atol=1e-5, + # rtol=1e-3, + # ) + # + # def test_self_scores(self) -> None: + # self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) + # + # score_args = ScoreArguments( + # score_dtype=torch.float64, + # per_sample_gradient_dtype=torch.float64, + # precondition_dtype=torch.float64, + # ) + # self.analyzer.compute_self_scores( + # scores_name=NEW_SCORE_NAME, + # factors_name=OLD_FACTOR_NAME, + # train_dataset=self.train_dataset, + # train_indices=list(range(TRAIN_INDICES)), + # per_device_train_batch_size=512, + # score_args=score_args, + # overwrite_output_dir=True, + # ) + # new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME) + # + # if LOCAL_RANK == 0: + # print(f"Previous score: {self_scores[ALL_MODULE_NAME]}") + # print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}") + # print(f"New score: {new_self_scores[ALL_MODULE_NAME]}") + # print(f"New shape: {new_self_scores[ALL_MODULE_NAME].shape}") + # assert check_tensor_dict_equivalence( + # self_scores, + # new_self_scores, + # atol=1e-5, + # rtol=1e-3, + # ) @classmethod def tearDownClass(cls) -> None: