Fixed factorial bug with torch tensors in rsh.py #2

Open · jikaelgagnon wants to merge 4 commits into main

Conversation

@jikaelgagnon commented Jul 9, 2024

Hi!

I've been trying to run the code example provided on Zenodo for the Orb-Net Equi paper. I've followed the instructions in data/code_example_qm9/README.md by running the following commands:

conda env create -n unite_env -f environment.yaml
conda activate unite_env
pip install -e .

The only change I made is that I installed PyTorch for CPU using the newest install command:
conda install pytorch torchvision torchaudio cpuonly -c pytorch

Upon running ./qm9_reproduce.sh I get the following error:

Figure3a
Energy U0 learning curve, delta learning
Training 1024
/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/pydantic/_internal/_fields.py:161: UserWarning: Field "model_version" has conflict with protected namespace "model_".

You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
  warnings.warn(
Traceback (most recent call last):
  File "/home/jikael/mila/projects/orb/data/code_example_qm9/orbnet2/inference.py", line 267, in <module>
    main()
  File "/home/jikael/mila/projects/orb/data/code_example_qm9/orbnet2/inference.py", line 246, in main
    model = MolModel(config=config)
  File "/home/jikael/mila/projects/orb/data/code_example_qm9/orbnet2/model.py", line 25, in __init__
    self.orbnet2 = OrbNet2(self._config)
  File "/home/jikael/mila/projects/orb/data/code_example_qm9/orbnet2/nn/model.py", line 106, in __init__
    self.embedding_lr = EmbeddingSO3(
  File "/home/jikael/mila/projects/orb/data/code_example_qm9/orbnet2/nn/o3.py", line 279, in __init__
    self.rshmodule = RSHxyz(max_l=max_l)
  File "/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/torch_gauge/o3/rsh.py", line 86, in __init__
    self._init_coefficients()
  File "/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/torch_gauge/o3/rsh.py", line 92, in _init_coefficients
    ns_lm = get_ns_lm(l, m)
  File "/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/joblib/memory.py", line 577, in __call__
    return self._cached_call(args, kwargs, shelving=False)[0]
  File "/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/joblib/memory.py", line 532, in _cached_call
    return self._call(call_id, args, kwargs, shelving)
  File "/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/joblib/memory.py", line 771, in _call
    output = self.func(*args, **kwargs)
  File "/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/torch_gauge/o3/rsh.py", line 42, in get_ns_lm
    return (1 / (2 ** torch.abs(m) * factorial(l))) * torch.sqrt(
  File "/home/jikael/anaconda3/envs/unite_env/lib/python3.9/site-packages/scipy/special/_basic.py", line 2994, in factorial
    raise ValueError(
ValueError: Unsupported datatype for factorial: <class 'torch.Tensor'>
Permitted data types are integers and floating point numbers

This is caused by o3.rsh.get_ns_lm, which passes a torch tensor to scipy.special.factorial, which does not support torch tensors.

To fix this, I added a torch_factorial function to o3.rsh which computes the factorial of a torch tensor.

After adding this, the script ./qm9_reproduce.sh ran as intended.
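
For reference, the error can be reproduced in isolation with a 0-dimensional tensor (a minimal sketch; the exact type check depends on the installed scipy version):

import torch
from scipy.special import factorial

# scipy treats a 0-dimensional tensor as a scalar and rejects its type:
# ValueError: Unsupported datatype for factorial: <class 'torch.Tensor'>
factorial(torch.tensor(3))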

Apologies if this post is not formatted correctly. I have never contributed to open source before! Also, sorry if this is submitted in the wrong place; perhaps this is a bug in the OrbNet code rather than torch-gauge. Super cool paper btw!

I have also added the elif statement missing in o3.spherical.py that was reported in this issue.

Best regards,

Jikael

@zrqiao (Owner) left a comment

Thanks for contributing!
Not sure if this was because I used torch < 1.0 back in the development phase of this project, but it's always right to keep the implementation torch-native.

from torch_gauge import ROOT_DIR
from torch_gauge.o3.spherical import SphericalTensor

memory = Memory(os.path.join(ROOT_DIR, ".o3_cache"), verbose=0)

def torch_factorial(x):
    # n! = gamma(n + 1), evaluated via torch.lgamma to stay torch-native
    return torch.exp(torch.lgamma(x + 1))
@zrqiao (Owner)

Just would like to check, does this have the desired behavior at x=1 or x=0?

@jikaelgagnon (Author)

Hey, it looks like it did work for x=0 and x=1, but it had rounding issues for larger x (e.g. x=9). It turns out the issue with using scipy.special.factorial is that it doesn't work on 0-dimensional tensors, so I wrote it like this instead:

def torch_factorial(x):
    # scipy.special.factorial rejects 0-dimensional tensors, so promote them
    # to 1-D, compute, and squeeze the result back to a scalar tensor
    if x.dim() == 0:
        x = x.unsqueeze(-1)
        out = special.factorial(x)
        return torch.from_numpy(out).squeeze()
    out = special.factorial(x)
    return torch.from_numpy(out)

I've tested it on zero-dimensional tensors (e.g. torch.tensor(9)) and higher dimensions (e.g. torch.arange(9).reshape(3, 3)), and I just committed this change. I also ran it with one of the example scripts from the OrbNet codebase and got the same predicted MAE as reported in the paper, so I think everything should be good.
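
For completeness, this is roughly the sanity check I ran (a standalone sketch, not part of the committed code):

import math
import torch
from scipy import special

# torch_factorial as committed above, repeated here so the snippet runs on its own
def torch_factorial(x):
    if x.dim() == 0:
        x = x.unsqueeze(-1)
        return torch.from_numpy(special.factorial(x)).squeeze()
    return torch.from_numpy(special.factorial(x))

# zero-dimensional input
assert torch.allclose(torch_factorial(torch.tensor(9)),
                      torch.tensor(float(math.factorial(9)), dtype=torch.float64))

# 2-D input
expected = torch.tensor([math.factorial(i) for i in range(9)],
                        dtype=torch.float64).reshape(3, 3)
assert torch.allclose(torch_factorial(torch.arange(9).reshape(3, 3)), expected)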
