Fixed factorial bug with torch tensors in rsh.py #2
base: main
Conversation
Thanks for contributing!
Not sure if this was because I used torch <1.0 back in the development phase of this project, but it's always right to keep the implementation torch-native.
torch_gauge/o3/rsh.py (outdated)

```python
from torch_gauge import ROOT_DIR
from torch_gauge.o3.spherical import SphericalTensor

memory = Memory(os.path.join(ROOT_DIR, ".o3_cache"), verbose=0)


def torch_factorial(x):
    return torch.exp(torch.lgamma(x + 1))
```
Just would like to check, does this have the desired behavior at x=1 or x=0?
Hey, looks like it did work for x=0 and x=1 but had rounding issues for larger x (e.g. x = 9). It turns out the issue with using scipy.special.factorial is that it doesn't work on 0-dimensional tensors. So I wrote it like this instead:
```python
def torch_factorial(x):
    if x.dim() == 0:
        x = x.unsqueeze(-1)
        out = special.factorial(x)
        return torch.from_numpy(out).squeeze()
    out = special.factorial(x)
    return torch.from_numpy(out)
```
I've tested on zero-dimensional tensors (e.g. torch.tensor(9)) and higher dimensions (e.g. torch.arange(9).reshape(3,3)). I just committed this change. I also ran it with one of the example scripts from the OrbNet codebase and it got the same predicted MAE as reported in the paper, so I think everything should be good.
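For context, the rounding behavior described above can be reproduced with the standard library alone: `math.lgamma` is the double-precision analogue of `torch.lgamma`, and exp(lgamma(x + 1)) is the formula from the original patch. This is only an illustrative sketch, not the torch-gauge code; the effect is typically more pronounced in torch's default float32 than in Python's float64 shown here.

```python
import math

def lgamma_factorial(x):
    # The approach from the original patch: x! = exp(lgamma(x + 1)).
    # Exact at x = 0 and x = 1 (lgamma(1) and lgamma(2) are both 0.0),
    # but subject to floating-point rounding for larger x.
    return math.exp(math.lgamma(x + 1))

print(lgamma_factorial(0))   # 1.0 exactly
print(lgamma_factorial(1))   # 1.0 exactly
print(lgamma_factorial(9))   # close to, but not guaranteed to be exactly, 362880
print(math.factorial(9))     # 362880, exact integer arithmetic
```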
Hi!
I've been trying to run the code example provided on Zenodo for the OrbNet-Equi paper. I've followed the instructions in data/code_example_qm9/README.md by running the commands listed there. The only change I made is that I installed PyTorch for CPU using the newest install command:

conda install pytorch torchvision torchaudio cpuonly -c pytorch

Upon running ./qm9_reproduce.sh I get an error. This is caused by o3.rsh.get_ns_lm, which tries to pass a torch tensor as input to scipy.special.factorial, which does not support torch tensors.

To fix this, I added a torch_factorial function to o3.rsh which computes the factorial of a torch tensor. After adding this, the script ./qm9_reproduce.sh ran as intended.

Apologies if this post is not formatted correctly; I have never contributed to open source before! Also, sorry if this is contributed in the wrong place; perhaps this is a bug in the OrbNet code rather than torch-gauge. Super cool paper btw!
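The shape-handling pattern in the fix (promote a 0-dimensional input, apply the elementwise factorial, then squeeze back) can be sketched with plain Python values in place of tensors. Here `math.factorial` and the scalar/list distinction are stand-ins for scipy.special.factorial and tensor dimensionality; this is an illustration of the pattern, not the actual torch-gauge code.

```python
import math

def factorial_like(x):
    # Stand-in for the PR's torch_factorial: a bare scalar plays the role
    # of a 0-dimensional tensor, a list plays the role of a batched tensor.
    if not isinstance(x, list):
        wrapped = [x]                              # analogue of unsqueeze(-1)
        out = [math.factorial(v) for v in wrapped] # elementwise factorial
        return out[0]                              # analogue of squeeze()
    return [math.factorial(v) for v in x]

print(factorial_like(9))          # 362880
print(factorial_like([0, 1, 5]))  # [1, 1, 120]
```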
I have also added the elif statement missing in o3.spherical.py that was reported in this issue.

Best regards,
Jikael