diff --git a/examples/uci/requirements.txt b/examples/uci/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/tests/gpu_tests/cpu_test.py b/tests/gpu_tests/cpu_test.py index 85de834..22691eb 100644 --- a/tests/gpu_tests/cpu_test.py +++ b/tests/gpu_tests/cpu_test.py @@ -12,8 +12,8 @@ LAMBDA_FACTOR_NAMES, ) from tests.gpu_tests.pipeline import ( - ClassificationTask, - construct_mnist_mlp, + GpuTestTask, + construct_test_mlp, get_mnist_dataset, ) from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES @@ -29,7 +29,7 @@ class CPUTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.model = construct_mnist_mlp() + cls.model = construct_test_mlp() cls.model.load_state_dict(torch.load("model.pth")) cls.model = cls.model.double() @@ -38,7 +38,7 @@ def setUpClass(cls) -> None: cls.eval_dataset = get_mnist_dataset(split="valid", data_path="data") cls.eval_dataset = data.Subset(cls.eval_dataset, indices=list(range(QUERY_INDICES))) - cls.task = ClassificationTask() + cls.task = GpuTestTask() cls.model = prepare_model(cls.model, cls.task) cls.analyzer = Analyzer(analysis_name="gpu_test", model=cls.model, task=cls.task, cpu=True)