Skip to content

Commit

Permalink
Add torch.compile tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 12, 2024
1 parent f7bd0a7 commit 9b09521
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
4 changes: 2 additions & 2 deletions tests/gpu_tests/compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def test_self_scores(self) -> None:
scores_name=NEW_SCORE_NAME,
factors_name=OLD_FACTOR_NAME,
train_dataset=self.train_dataset,
train_indices=list(range(42)),
per_device_train_batch_size=4,
train_indices=list(range(TRAIN_INDICES)),
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
)
Expand Down
37 changes: 19 additions & 18 deletions tests/gpu_tests/fsdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@

import torch
import torch.distributed as dist
from analyzer import Analyzer, prepare_model
from arguments import FactorArguments, ScoreArguments
from module.constants import (
ALL_MODULE_NAME,
COVARIANCE_FACTOR_NAMES,
LAMBDA_FACTOR_NAMES,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, size_based_auto_wrap_policy, wrap
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel
from torch.utils import data

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments, ScoreArguments
from kronfluence.module.constants import (
ALL_MODULE_NAME,
COVARIANCE_FACTOR_NAMES,
LAMBDA_FACTOR_NAMES,
)
from tests.gpu_tests.pipeline import (
ClassificationTask,
construct_mnist_mlp,
get_mnist_dataset,
)
from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES
from tests.utils import check_tensor_dict_equivalence

LOCAL_RANK = int(os.environ["LOCAL_RANK"])
Expand All @@ -42,9 +43,9 @@ def setUpClass(cls) -> None:
cls.model = cls.model.double()

cls.train_dataset = get_mnist_dataset(split="train", data_path="data")
cls.train_dataset = data.Subset(cls.train_dataset, indices=list(range(200)))
cls.train_dataset = data.Subset(cls.train_dataset, indices=list(range(TRAIN_INDICES)))
cls.eval_dataset = get_mnist_dataset(split="valid", data_path="data")
cls.eval_dataset = data.Subset(cls.eval_dataset, indices=list(range(100)))
cls.eval_dataset = data.Subset(cls.eval_dataset, indices=list(range(QUERY_INDICES)))

cls.task = ClassificationTask()
cls.model = prepare_model(cls.model, cls.task)
Expand Down Expand Up @@ -76,7 +77,7 @@ def test_covariance_matrices(self) -> None:
factors_name=NEW_FACTOR_NAME,
dataset=self.train_dataset,
factor_args=factor_args,
per_device_batch_size=16,
per_device_batch_size=512,
overwrite_output_dir=True,
)
new_covariance_factors = self.analyzer.load_covariance_matrices(factors_name=NEW_FACTOR_NAME)
Expand Down Expand Up @@ -107,7 +108,7 @@ def test_lambda_matrices(self):
factors_name=NEW_FACTOR_NAME,
dataset=self.train_dataset,
factor_args=factor_args,
per_device_batch_size=16,
per_device_batch_size=512,
overwrite_output_dir=True,
load_from_factors_name=OLD_FACTOR_NAME,
)
Expand Down Expand Up @@ -140,10 +141,10 @@ def test_pairwise_scores(self) -> None:
factors_name=OLD_FACTOR_NAME,
query_dataset=self.eval_dataset,
train_dataset=self.train_dataset,
train_indices=list(range(42)),
query_indices=list(range(23)),
per_device_query_batch_size=2,
per_device_train_batch_size=4,
train_indices=list(range(TRAIN_INDICES)),
query_indices=list(range(QUERY_INDICES)),
per_device_query_batch_size=12,
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
)
Expand Down Expand Up @@ -173,8 +174,8 @@ def test_self_scores(self) -> None:
scores_name=NEW_SCORE_NAME,
factors_name=OLD_FACTOR_NAME,
train_dataset=self.train_dataset,
train_indices=list(range(42)),
per_device_train_batch_size=4,
train_indices=list(range(TRAIN_INDICES)),
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
)
Expand Down

0 comments on commit 9b09521

Please sign in to comment.