Skip to content

Commit

Permalink
Include QB tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 19, 2024
1 parent 9462fdc commit 74f9ed3
Showing 1 changed file with 39 additions and 4 deletions.
43 changes: 39 additions & 4 deletions tests/gpu_tests/ddp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
LAMBDA_FACTOR_NAMES,
)
from tests.gpu_tests.pipeline import (
ClassificationTask,
construct_mnist_mlp,
GpuTestTask,
construct_test_mlp,
get_mnist_dataset,
)
from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES
Expand All @@ -35,7 +35,7 @@
class DDPTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.model = construct_mnist_mlp()
cls.model = construct_test_mlp()
cls.model.load_state_dict(torch.load("model.pth"))
cls.model = cls.model.double()

Expand All @@ -44,7 +44,7 @@ def setUpClass(cls) -> None:
cls.eval_dataset = get_mnist_dataset(split="valid", data_path="data")
cls.eval_dataset = data.Subset(cls.eval_dataset, indices=list(range(QUERY_INDICES)))

cls.task = ClassificationTask()
cls.task = GpuTestTask()
cls.model = prepare_model(cls.model, cls.task)

dist.init_process_group("nccl", rank=WORLD_RANK, world_size=WORLD_SIZE)
Expand Down Expand Up @@ -187,6 +187,41 @@ def test_self_scores(self) -> None:
rtol=1e-3,
)

def test_lr_pairwise_scores(self) -> None:
pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb")

score_args = ScoreArguments(
score_dtype=torch.float64,
per_sample_gradient_dtype=torch.float64,
precondition_dtype=torch.float64,
query_gradient_rank=32
)
self.analyzer.compute_pairwise_scores(
scores_name="ddp_qb",
factors_name=OLD_FACTOR_NAME,
query_dataset=self.eval_dataset,
train_dataset=self.train_dataset,
train_indices=list(range(TRAIN_INDICES)),
query_indices=list(range(QUERY_INDICES)),
per_device_query_batch_size=12,
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
)
new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME)

if LOCAL_RANK == 0:
print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
assert check_tensor_dict_equivalence(
pairwise_scores,
new_pairwise_scores,
atol=1e-5,
rtol=1e-3,
)

@classmethod
def tearDownClass(cls) -> None:
dist.destroy_process_group()
Expand Down

0 comments on commit 74f9ed3

Please sign in to comment.