
[BUG]: torch export fails for expressions with constant inputs e.g. exp(2) #656

Open
tbuckworth opened this issue Jun 20, 2024 · 12 comments · Fixed by #658
Labels
bug Something isn't working

Comments

@tbuckworth

What happened?

sympy2torch produces a module that fails when called if a function of a constant is present in the expression.

For example:

from sympy import symbols, exp
from pysr import sympy2torch
import torch

x, y = symbols("x y")

expression = exp(2)

module = sympy2torch(expression, [x, y])

X = torch.rand(100, 2).float() * 10

torch_out = module(X)

produces this error

TypeError: exp(): argument 'input' (position 1) must be Tensor, not float

I've tried other expressions like log(4), which produces the same problem.

The current mapping in export_torch.py is sympy.exp: torch.exp.

I believe that

def exp(x):
    return torch.exp(torch.FloatTensor(x))

then using the mapping sympy.exp: exp might work, but I have been unable to test it. Adding it to extra_sympy_mappings doesn't work, I think because those mappings are chained to the end of the existing ones and don't override the original.
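A variant of this idea that also accepts tensors could look like the sketch below. It is not PySR code: `tensor_safe` is a hypothetical helper, and it assumes a constant arrives as a plain Python number that should take PyTorch's default floating dtype. It uses `torch.as_tensor` rather than `torch.FloatTensor(x)`, because the legacy `torch.FloatTensor(2)` constructor allocates an uninitialized tensor of length 2 instead of wrapping the value 2.

```python
import torch


def tensor_safe(fn):
    """Hypothetical helper (not part of PySR): wrap a torch function so that
    Python int/float arguments are converted to tensors before the call."""

    def wrapped(*args):
        converted = [
            a if isinstance(a, torch.Tensor)
            else torch.as_tensor(a, dtype=torch.get_default_dtype())
            for a in args
        ]
        return fn(*converted)

    return wrapped


safe_exp = tensor_safe(torch.exp)

# torch.exp itself rejects a plain float, which is the error in this issue.
try:
    torch.exp(2.0)
except TypeError as err:
    print("torch.exp(2.0) failed:", err)

print(safe_exp(2.0))             # 0-d tensor holding exp(2)
print(safe_exp(torch.zeros(3)))  # tensors are passed through unchanged
```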

Alternatively, simplifying constant subexpressions to numbers wherever possible might solve the problem for all expressions, e.g. exp(2) would become 7.38905609893.
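One way to try the constant-folding idea outside PySR is SymPy's `evalf()`, which evaluates purely numeric subexpressions to `Float`s while leaving symbols in place. This is a sketch of a possible workaround, not something PySR does. One caveat: `evalf()` may also turn exact integer coefficients into `Float`s, and `sympy2torch` turns `Float` leaves into trainable parameters (per the branch code quoted later in this thread).

```python
import sympy

x2, x3 = sympy.symbols("x2 x3")

# The expression PySR learned in this issue, rebuilt by hand in SymPy
# (square(...) written as **2).
expr = (x2 / 0.10893087) ** 2 * sympy.exp(x3) - sympy.exp(sympy.sign(0.44796443)) ** 2

# evalf() replaces numeric subtrees such as exp(2) with Floats;
# exp(x3) depends on a symbol, so it stays symbolic.
folded = expr.evalf()
print(expr)
print(folded)
```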

Version

0.18.4

Operating System

Linux

Package Manager

pip

Interface

Script (i.e., python my_script.py)

Relevant log output

No response

Extra Info

No response

@tbuckworth tbuckworth added the bug Something isn't working label Jun 20, 2024
@MilesCranmer
Owner

Thanks for the bug report! One thing I am surprised about is that I thought this would already happen? See this code:

if issubclass(expr.func, sympy.Float):
    self._value = torch.nn.Parameter(torch.tensor(float(expr)))
    self._torch_func = lambda: self._value
    self._args = ()

The code torch.tensor(float(expr)) should already map any constant into a torch tensor.

Maybe the issue is that you are explicitly passing a Python integer, rather than a SymPy integer?

For example, we can see these are actually different classes:

In [4]: isinstance(1, sympy.Integer)
Out[4]: False

Did you see this error from a PySR export, or are you trying to use sympy2torch manually and putting in the integers explicitly?
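The int-versus-`Integer` distinction can be checked directly. As a side note (a SymPy fact, not something stated in this thread), SymPy functions such as `exp` sympify their arguments, so the literal 2 in `exp(2)` is stored as a SymPy `Integer` inside the expression tree.

```python
import sympy

print(isinstance(1, sympy.Integer))                 # False: plain Python int
print(isinstance(sympy.sympify(1), sympy.Integer))  # True after sympify

# Collect the leaves (nodes without arguments) of exp(2); the 2 is a SymPy Integer.
leaves = [node for node in sympy.preorder_traversal(sympy.exp(2)) if not node.args]
print(leaves)
```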

@tbuckworth
Author

This came about using PySRRegressor.fit(), which produced an expression containing square(exp(sign(0.44796443))), which seems to simplify to exp(2).

I recreated the issue using expression = exp(sign(0.44796443))*exp(sign(0.44796443)) originally, but wrote exp(2) here as a minimal example.
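The simplification can be reproduced in SymPy alone, without PySR: `sign` of a positive Float evaluates to the integer 1, `exp(1)` evaluates to the constant `E`, and squaring gives `exp(2)`. `E` is an instance of `sympy.core.numbers.Exp1`, the class that appears in the `extra_torch_mappings` discussed below.

```python
import sympy

s = sympy.sign(0.44796443)  # positive Float -> SymPy's integer 1
e = sympy.exp(s)            # exp(1) -> the constant E
sq = e ** 2                 # square(...) -> exp(2)

print(repr(s), repr(e), repr(sq))
print(type(e).__name__)     # Exp1
print(sq.free_symbols)      # set(): no variables left, a pure constant
```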

@MilesCranmer
Owner

MilesCranmer commented Jun 20, 2024

Do you know the original error message? It could be that the MWE is actually a different issue. The exp(2) should never actually occur; it should (I think) be exp(sympy.Integer(2)). At least, that is what should happen.

Perhaps, because it came from sign(...), it is some kind of floating-point number that PySR doesn't account for.

@tbuckworth
Author

tbuckworth commented Jun 21, 2024

Oh, apologies if this is my fault, but I was using extra_torch_mappings that included:

sympy.core.numbers.Exp1: exp1

where

def exp1():
    return torch.exp(torch.FloatTensor([1]))

I believe I added this due to an error arising when trying to export to torch an expression containing exp(sign(0.1...)), but I don't remember exactly.

In terms of the original error in this issue, the PySRRegressor.fit function learned this expression:

(square(x2 / 0.10893087) * exp(x3)) - square(exp(sign(0.44796443)))

I then called model.pytorch(), which resulted in this error:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
    (1): _Node(
      (_args): ModuleList(
        (0): _Node()
      )
    )
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 151, in forward
    arg_ = memodict[arg]
KeyError: _Node(
  (_args): ModuleList(
    (0): _Node()
  )
)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 300, in <module>
    optimize_hyperparams(bounds, fixed, project, id_tag, run_graph_hyperparameters)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 141, in optimize_hyperparams
    run_next(hparams)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/hyperparameter_optimization.py", line 116, in run_graph_hyperparameters
    run_graph_neurosymbolic_search(args)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 503, in run_graph_neurosymbolic_search
    fine_tuned_policy = fine_tune(ns_agent.policy, logdir, symbdir, hp_override)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/graph_sr.py", line 397, in fine_tune
    agent.train(args.num_timesteps)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 213, in train
    act, value = self.predict(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/agents/ppo_model.py", line 107, in predict
    dist, value = self.policy(obs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 124, in forward
    d, r = self.all_dones_rewards(s)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 156, in all_dones_rewards
    dones, rew = self.dr(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/policy.py", line 102, in dr
    d = self.done_model(sa)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1065, in forward
    return self.fwd(X)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/common/model.py", line 1061, in fwd
    return self.model._node(symbols)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 153, in forward
    arg_ = arg(memodict)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/vol/bitbucket/tfb115/train-procgen-pytorch/venvcartpole/lib/python3.8/site-packages/pysr/export_torch.py", line 156, in forward
    return self._torch_func(*args)
TypeError: exp(): argument 'input' (position 1) must be Tensor, not float

@MilesCranmer
Owner

MilesCranmer commented Jun 21, 2024

It might be because your function exp1 is returning torch.exp(torch.FloatTensor([1])) rather than torch.exp(torch.tensor(1))? (note the difference in shape)

- torch.exp(torch.FloatTensor([1]))
+ torch.exp(torch.tensor(1))

But normally I would just do

extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1.0))}

which is similar to the definitions for sympy.core.numbers.Half and sympy.core.numbers.One.
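The shape difference mentioned above is easy to see. This only illustrates the three candidate return values (the names `a`, `b`, `c` are arbitrary, and `torch.tensor(1.0)` is written with a float so the tensor is floating point from the start). Note that `math.exp(1.0)` returns a plain Python float, not a tensor.

```python
import math
import torch

a = torch.exp(torch.FloatTensor([1]))  # 1-D tensor, shape (1,)
b = torch.exp(torch.tensor(1.0))       # 0-D (scalar) tensor, shape ()
c = math.exp(1.0)                      # plain Python float, not a tensor

print(a.shape, b.shape, type(c).__name__)

# Both tensors broadcast against a batch of inputs of shape (10, 1).
X = torch.randn(10, 1)
print((X * a).shape, (X * b).shape)
```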

@MilesCranmer
Owner

Ok, weird, I can actually reproduce it with the following, which seems to be the same as what you saw originally:

import math
import pysr
import sympy
import torch

ex = pysr.export_sympy.pysr2sympy(
    "square(exp(sign(0.44796443))) + 1.5 * x1",
    feature_names_in=["x1"],
    extra_sympy_mappings={"square": lambda x: x**2},
)


def exp1():
    return torch.exp(torch.FloatTensor([1]))


m = pysr.export_torch.sympy2torch(
    ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: exp1}
)
m(torch.randn(10, 1))  # Errors


m2 = pysr.export_torch.sympy2torch(
    ex, ["x1"], extra_torch_mappings={sympy.core.numbers.Exp1: (lambda: math.exp(1))}
)
m2(torch.randn(10, 1))  # Also errors

@MilesCranmer
Owner

Ah, I got it! It's because we don't have a branch for sympy.core.numbers.NumberSymbol. Argh...

From PySR/pysr/export_torch.py, lines 94 to 122 at commit 06ca0e3:

if issubclass(expr.func, sympy.Float):
    self._value = torch.nn.Parameter(torch.tensor(float(expr)))
    self._torch_func = lambda: self._value
    self._args = ()
elif issubclass(expr.func, sympy.Rational):
    # This is some fraction fixed in the operator.
    self._value = float(expr)
    self._torch_func = lambda: self._value
    self._args = ()
elif issubclass(expr.func, sympy.UnevaluatedExpr):
    if len(expr.args) != 1 or not issubclass(
        expr.args[0].func, sympy.Float
    ):
        raise ValueError(
            "UnevaluatedExpr should only be used to wrap floats."
        )
    self.register_buffer("_value", torch.tensor(float(expr.args[0])))
    self._torch_func = lambda: self._value
    self._args = ()
elif issubclass(expr.func, sympy.Integer):
    # Can get here if expr is one of the Integer special cases,
    # e.g. NegativeOne
    self._value = int(expr)
    self._torch_func = lambda: self._value
    self._args = ()
elif issubclass(expr.func, sympy.Symbol):
    self._name = expr.name
    self._torch_func = lambda value: value
    self._args = ((lambda memodict: memodict[expr.name]),)

This will also need to be added to the sympy2jax code, I guess.
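A quick SymPy-only check of this diagnosis: `E` and `pi` are `NumberSymbol` instances, and their classes are not subclasses of any type tested by the branches quoted above, so none of those branches matches them. `float()` still converts them cleanly.

```python
import sympy

# The types tested by the issubclass(expr.func, ...) branches quoted above.
branch_types = (sympy.Float, sympy.Rational, sympy.UnevaluatedExpr, sympy.Integer, sympy.Symbol)

for const in (sympy.E, sympy.pi):
    print(
        const,
        isinstance(const, sympy.NumberSymbol),     # True
        issubclass(const.func, branch_types),      # False: no existing branch applies
        float(const),                              # converts cleanly to a Python float
    )
```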

@tbuckworth
Author

I see this was fixed in version 0.19.0.

However, the issue still arises now and then for me with the function sin.

I can recreate the issue with the following code:

from sympy import symbols, sin, sign
from pysr import sympy2torch
import torch

x, y = symbols("x y")

expression = sin(sign(-0.041662704))

module = sympy2torch(expression, [x, y])

X = torch.rand(100, 2).float() * 10

torch_out = module(X)

TypeError: sin(): argument 'input' (position 1) must be Tensor, not float

@MilesCranmer
Owner

Thanks for making an MWE. I’ll take a look. It seems like running sympy2torch directly on a float causes the issue?

@MilesCranmer MilesCranmer reopened this Sep 26, 2024
@tbuckworth
Author

If I run it directly on a float I get a different error:

AttributeError: 'float' object has no attribute 'func'

@tbuckworth
Author

If you remove sign from the original expression, there's no error, but if you replace the expression with sin(-1), it throws the same error.

I would guess it's still to do with sympy.core.numbers.NumberSymbol?
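Inspecting the expression tree in SymPy may help narrow this down. `sign(-0.041662704)` evaluates to `-1`, and SymPy rewrites `sin(-1)` as `-sin(1)`, so the leaf under `sin` is the integer constant `1`, not a `NumberSymbol`. Because `sympy.Integer` subclasses `sympy.Rational`, the `Rational` branch in the code quoted earlier (which is checked before the `Integer` branch) would turn that leaf into a Python float. That would be consistent with the `must be Tensor, not float` message, but it is an inference from the quoted snippet, not a confirmed diagnosis.

```python
import sympy

expr = sympy.sin(sympy.sign(-0.041662704))

print(expr)               # -sin(1)
print(sympy.srepr(expr))  # full tree, showing the Integer leaves

# Find the argument that ends up inside sin(...)
sin_node = next(n for n in sympy.preorder_traversal(expr) if isinstance(n, sympy.sin))
leaf = sin_node.args[0]

print(type(leaf).__name__)                   # One (SymPy's singleton integer 1)
print(isinstance(leaf, sympy.Rational))      # True: Integer is a subclass of Rational
print(isinstance(leaf, sympy.NumberSymbol))  # False
```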

@tbuckworth
Author

I've proposed a change in #726 to this code:

From PySR/pysr/export_torch.py, lines 117 to 121 at commit 339cc0a:

elif issubclass(expr.func, sympy.NumberSymbol):
    # Can get here from exp(1) or exact pi
    self._value = float(expr)
    self._torch_func = lambda: self._value
    self._args = ()

Is that feasible? or do you think it would break other behaviour?
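One possible variant of the proposed branch is sketched below, under the assumption that downstream torch functions always expect tensors. Instead of storing `float(expr)`, every constant leaf (including `NumberSymbol`s such as `E` and `pi`) is wrapped in a 0-d tensor, so `torch.sin` or `torch.exp` never receives a bare Python float. `constant_leaf_to_tensor` is a hypothetical helper, not PySR's `_Node` code, and it deliberately ignores the trainable `Parameter` handling PySR uses for `Float`s.

```python
import sympy
import torch


def constant_leaf_to_tensor(expr):
    """Hypothetical helper (not PySR's _Node code): map a SymPy constant leaf
    to a 0-d tensor so torch functions never receive a bare Python number."""
    # Integer is a subclass of Rational, so integers are covered by this check too.
    if isinstance(expr, (sympy.Float, sympy.Rational, sympy.NumberSymbol)):
        return torch.tensor(float(expr), dtype=torch.get_default_dtype())
    raise TypeError(f"not a constant leaf: {expr!r}")


for leaf in (sympy.Float(0.5), sympy.Rational(1, 2), sympy.Integer(-1), sympy.E, sympy.pi):
    # torch.sin accepts the result because it is always a tensor.
    print(leaf, torch.sin(constant_leaf_to_tensor(leaf)))
```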


2 participants