diff --git a/.gitignore b/.gitignore index f57e9337..cbeb4fda 100644 --- a/.gitignore +++ b/.gitignore @@ -206,3 +206,4 @@ Temporary Items .apdisk tmp_test/allset_test.py tmp_test/ +RandomNetwork.txt \ No newline at end of file diff --git a/easygraph/nn/tests/test_regularization.py b/easygraph/nn/tests/test_regularization.py index 2311030e..dd9047d0 100644 --- a/easygraph/nn/tests/test_regularization.py +++ b/easygraph/nn/tests/test_regularization.py @@ -2,10 +2,12 @@ import pytest import torch +from easygraph.nn.regularization import EmbeddingRegularization + def test_embedding_reg(): print("EmbeddingRegularization" in eg.__dir__()) - emb_reg = eg.EmbeddingRegularization(p=2, weight_decay=1e-4) + emb_reg = EmbeddingRegularization(p=2, weight_decay=1e-4) embs = [torch.randn(10, 3), torch.randn(10, 3)] loss = emb_reg(*embs) true_loss = 0 diff --git a/requirements.txt b/requirements.txt index 056ca6da..cfff14ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,8 @@ pandas>=2.0.1 nose>=1.3.7 pybind11>=2.10.4 pydsge +torch-geometric +torch-sparse +torch-scatter torch >= 1.12.1 -requests \ No newline at end of file +requests