Skip to content

Commit

Permalink
Add CPU tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 19, 2024
1 parent 74f9ed3 commit 04efc72
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Empty file added examples/uci/requirements.txt
Empty file.
8 changes: 4 additions & 4 deletions tests/gpu_tests/cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
LAMBDA_FACTOR_NAMES,
)
from tests.gpu_tests.pipeline import (
ClassificationTask,
construct_mnist_mlp,
GpuTestTask,
construct_test_mlp,
get_mnist_dataset,
)
from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES
Expand All @@ -29,7 +29,7 @@
class CPUTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.model = construct_mnist_mlp()
cls.model = construct_test_mlp()
cls.model.load_state_dict(torch.load("model.pth"))
cls.model = cls.model.double()

Expand All @@ -38,7 +38,7 @@ def setUpClass(cls) -> None:
cls.eval_dataset = get_mnist_dataset(split="valid", data_path="data")
cls.eval_dataset = data.Subset(cls.eval_dataset, indices=list(range(QUERY_INDICES)))

cls.task = ClassificationTask()
cls.task = GpuTestTask()
cls.model = prepare_model(cls.model, cls.task)

cls.analyzer = Analyzer(analysis_name="gpu_test", model=cls.model, task=cls.task, cpu=True)
Expand Down

0 comments on commit 04efc72

Please sign in to comment.