diff --git a/tests/gpu_tests/ddp_test.py b/tests/gpu_tests/ddp_test.py index b500c67..f042a7f 100644 --- a/tests/gpu_tests/ddp_test.py +++ b/tests/gpu_tests/ddp_test.py @@ -15,8 +15,8 @@ LAMBDA_FACTOR_NAMES, ) from tests.gpu_tests.pipeline import ( - ClassificationTask, - construct_mnist_mlp, + GpuTestTask, + construct_test_mlp, get_mnist_dataset, ) from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES @@ -35,7 +35,7 @@ class DDPTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.model = construct_mnist_mlp() + cls.model = construct_test_mlp() cls.model.load_state_dict(torch.load("model.pth")) cls.model = cls.model.double() @@ -44,7 +44,7 @@ def setUpClass(cls) -> None: cls.eval_dataset = get_mnist_dataset(split="valid", data_path="data") cls.eval_dataset = data.Subset(cls.eval_dataset, indices=list(range(QUERY_INDICES))) - cls.task = ClassificationTask() + cls.task = GpuTestTask() cls.model = prepare_model(cls.model, cls.task) dist.init_process_group("nccl", rank=WORLD_RANK, world_size=WORLD_SIZE) @@ -187,6 +187,41 @@ def test_self_scores(self) -> None: rtol=1e-3, ) + def test_lr_pairwise_scores(self) -> None: + pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb") + + score_args = ScoreArguments( + score_dtype=torch.float64, + per_sample_gradient_dtype=torch.float64, + precondition_dtype=torch.float64, + query_gradient_rank=32 + ) + self.analyzer.compute_pairwise_scores( + scores_name="ddp_qb", + factors_name=OLD_FACTOR_NAME, + query_dataset=self.eval_dataset, + train_dataset=self.train_dataset, + train_indices=list(range(TRAIN_INDICES)), + query_indices=list(range(QUERY_INDICES)), + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) + + if LOCAL_RANK == 0: + print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") + print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") + print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") + print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") + assert check_tensor_dict_equivalence( + pairwise_scores, + new_pairwise_scores, + atol=1e-5, + rtol=1e-3, + ) + @classmethod def tearDownClass(cls) -> None: dist.destroy_process_group()