[BUG]: torch export fails for expressions with constant inputs e.g. exp(2) #656
Thanks for the bug report! One thing that surprises me is that I thought this case was already handled. See this code (lines 94 to 97 in 06ca0e3).
Maybe the issue is that you are explicitly passing a Python integer rather than a SymPy integer? These are actually different classes:

In [4]: isinstance(1, sympy.Integer)
Out[4]: False

Did you see this error from a PySR export, or are you trying to use
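For reference, here is a small SymPy check of that distinction, using only standard SymPy calls. A bare Python `int` is not a `sympy.Integer`, but SymPy converts numeric arguments when it builds an expression, so the `2` inside `exp(2)` is stored as a `sympy.Integer`:

```python
import sympy

# A plain Python int is not a SymPy Integer...
print(isinstance(1, sympy.Integer))                 # False
# ...but sympify (and SymPy constructors) convert it.
print(isinstance(sympy.sympify(1), sympy.Integer))  # True

# Function arguments are sympified automatically: exp(2) stays unevaluated
# and its single argument is a sympy.Integer, not a Python int.
e2 = sympy.exp(2)
print(e2.func is sympy.exp)                         # True
print(isinstance(e2.args[0], sympy.Integer))        # True
```

So for an expression built through SymPy, the constant argument of `exp` reaches the exporter as a SymPy number class rather than a Python scalar.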
This came about using I recreated the issue using
Do you know the original error message? It could be that the MWE is actually a different thing. The perhaps because it was
Oh, apologies if this is my fault, but I was using
where
I believe I added this due to an error arising when trying to export to torch an expression containing In terms of the original error in this issue, the
I then called

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
    (1): _Node(
      (_args): ModuleList(
        (0): _Node()
      )
    )
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 300, in <module>
    optimize_hyperparams(bounds, fixed, project, id_tag, run_graph_hyperparameters)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 141, in optimize_hyperparams
    run_next(hparams)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 116, in run_graph_hyperparameters
    run_graph_neurosymbolic_search(args)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 503, in run_graph_neurosymbolic_search
    fine_tuned_policy = fine_tune(ns_agent.policy, logdir, symbdir, hp_override)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 397, in fine_tune
    agent.train(args.num_timesteps)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 213, in train
    act, value = self.predict(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 107, in predict
    dist, value = self.policy(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 124, in forward
    d, r = self.all_dones_rewards(s)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 156, in all_dones_rewards
    dones, rew = self.dr(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 102, in dr
    d = self.done_model(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1065, in forward
    return self.fwd(X)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1061, in fwd
    return self.model._node(symbols)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 156, in forward
    return self._torch_func(*args)
TypeError: exp(): argument 'input' (position 1) must be Tensor, not float
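The last error in the chain is raised by `torch.exp` itself: it only accepts tensors, so any constant node that evaluates to a plain Python `float` or `int` fails at call time. One possible workaround is sketched below with a hypothetical helper `tensor_safe` (not part of PySR or PyTorch). It coerces scalar arguments to tensors before calling the mapped torch function:

```python
import math
import torch

def tensor_safe(torch_func):
    """Hypothetical wrapper: convert Python int/float arguments to tensors
    (in the default float dtype) before calling `torch_func`."""
    def wrapped(*args):
        converted = [
            a if isinstance(a, torch.Tensor)
            else torch.as_tensor(a, dtype=torch.get_default_dtype())
            for a in args
        ]
        return torch_func(*converted)
    return wrapped

# torch.exp rejects a bare Python float with a TypeError...
try:
    torch.exp(2.0)
    raw_call_failed = False
except TypeError:
    raw_call_failed = True

# ...while the wrapped version returns a 0-d tensor holding exp(2).
safe_exp = tensor_safe(torch.exp)
value = safe_exp(2.0).item()
print(raw_call_failed, math.isclose(value, math.exp(2), rel_tol=1e-6))  # True True
```

Whether such a wrapper can replace the built-in `sympy.exp: torch.exp` entry depends on how the exporter merges user-supplied mappings, which this thread does not settle.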
It might be because your function

- torch.exp(torch.FloatTensor([1]))
+ torch.exp(torch.tensor(1))

But normally I would just do

extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1.0))}

which is similar to the definitions for
Ok, weird, I can actually reproduce it with the following, which sounds like the same error you saw originally:

import math
import pysr
import sympy
import torch

ex = pysr.export_sympy.pysr2sympy(
    "square(exp(sign(0.44796443))) + 1.5 * x1",
    feature_names_in=["x1"],
    extra_sympy_mappings={"square": lambda x: x**2},
)

def exp1():
    return torch.exp(torch.FloatTensor([1]))

m = pysr.export_torch.sympy2torch(
    ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: exp1}
)
m(torch.randn(10, 1))  # Errors

m2 = pysr.export_torch.sympy2torch(
    ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1))}
)
m2(torch.randn(10, 1))  # Also errors
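Part of what makes this reproduction interesting is that SymPy simplifies the constant subtree before export. The sketch below uses plain SymPy and builds the same constant by hand rather than via `pysr2sympy`. It shows that the sign of a positive float evaluates to `1`, and that `exp(1)` is rewritten to the named constant `E`, an instance of `sympy.core.numbers.Exp1`. `E` is a `NumberSymbol` rather than a `sympy.Number`, so an exporter that dispatches on SymPy number types would presumably need separate handling for it. That is what the `Exp1` entries in `extra_torch_mappings` above try to provide.

```python
import sympy

# Build the constant part of "square(exp(sign(0.44796443)))" by hand.
s = sympy.sign(sympy.Float(0.44796443))
print(s)                                                # 1
e = sympy.exp(s)
print(e == sympy.E)                                     # True: exp(1) -> E
print(isinstance(e, sympy.core.numbers.Exp1))           # True
print(isinstance(e, sympy.core.numbers.NumberSymbol))   # True
print(e.is_Number, e.is_number)                         # False True
```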
Ah, I got it! It's because we don't have a branch for
Lines 94 to 122 in 06ca0e3
This will also need to be added to the sympy2jax code, I guess.
I see this was fixed in version 0.19.0. However, the issue still arises now and then for me, with the function
I can recreate the issue with the following code:
Thanks for making an MWE. I’ll take a look. It seems like if you run
If I run it directly on a float, I get a different error:
If you remove I would guess it's still to do with
attempting to address MilesCranmer#656
I've proposed a change in #726 to this code (lines 117 to 121 in 339cc0a).
Is that feasible, or do you think it would break other behaviour?
What happened?
sympy2torch produces a module that fails when called if a function of a constant is present in the expression.
For example:
produces this error
I've tried other expressions like log(4), which produces the same problem.
The current mapping in export_torch.py is sympy.exp: torch.exp. I believe that
then using the mapping sympy.exp: exp
might work, but I have been unable to test it (adding to extra_sympy_mappings doesn't work, I think because it is chained to the end of the existing mappings and doesn't override the original one).

Alternatively, perhaps evaluating constant subexpressions to numbers wherever possible might solve the problem for all such expressions, e.g. exp(2) becomes 7.38905609893.
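The second idea, evaluating constant subexpressions before export, can be sketched in SymPy alone. `fold_constants` below is a hypothetical pre-processing helper, not something PySR provides. It replaces every symbol-free subexpression that is not already a literal number (for example `exp(2)`, `log(4)`, or `E`) with its floating-point value, so those constants reach the exporter as `sympy.Float` leaves:

```python
import math
import sympy

def fold_constants(expr: sympy.Expr) -> sympy.Expr:
    """Hypothetical helper: replace symbol-free subexpressions such as
    exp(2) or log(4) with their Float values."""
    replacements = {
        node: node.evalf()
        for node in sympy.preorder_traversal(expr)
        # is_number: no free symbols. Skipping is_Number leaves literal
        # numbers (e.g. an integer exponent in x1**2) untouched.
        if node.is_number and not node.is_Number
    }
    # xreplace swaps the outermost matching nodes, walking top-down.
    return expr.xreplace(replacements)

x1 = sympy.Symbol("x1")
expr = sympy.exp(2) + sympy.log(4) * x1
folded = fold_constants(expr)
print(folded)
```

The trade-off is that exact constants become floating-point values (with `evalf`'s default 15 significant digits), so `exp(2)` becomes roughly 7.38905609893.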
Version
0.18.4
Operating System
Linux
Package Manager
pip
Interface
Script (i.e., python my_script.py)
Relevant log output
No response
Extra Info
No response