diff --git a/pysr/export_jax.py b/pysr/export_jax.py index ba36bfc6..4f8ead20 100644 --- a/pysr/export_jax.py +++ b/pysr/export_jax.py @@ -1,5 +1,6 @@ import numpy as np # noqa: F401 import sympy # type: ignore +from sympy.codegen.cfunctions import log2, log10 # type: ignore # Special since need to reduce arguments. MUL = 0 @@ -15,6 +16,8 @@ sympy.ceiling: "jnp.ceil", sympy.floor: "jnp.floor", sympy.log: "jnp.log", + log2: "jnp.log2", + log10: "jnp.log10", sympy.exp: "jnp.exp", sympy.sqrt: "jnp.sqrt", sympy.cos: "jnp.cos", diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index ea54b01c..925c284b 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -4,6 +4,7 @@ import sympy # type: ignore from sympy import sympify +from sympy.codegen.cfunctions import log2, log10 # type: ignore from .utils import ArrayLike @@ -39,8 +40,8 @@ "erf": sympy.erf, "erfc": sympy.erfc, "log": lambda x: sympy.log(x), - "log10": lambda x: sympy.log(x, 10), - "log2": lambda x: sympy.log(x, 2), + "log10": lambda x: log10(x), + "log2": lambda x: log2(x), "log1p": lambda x: sympy.log(x + 1), "log_abs": lambda x: sympy.log(abs(x)), "log10_abs": lambda x: sympy.log(abs(x), 10), diff --git a/pysr/export_torch.py b/pysr/export_torch.py index be3d6a16..eb3ccd8a 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -5,6 +5,7 @@ import numpy as np # noqa: F401 import sympy # type: ignore +from sympy.codegen.cfunctions import log2, log10 # type: ignore def _reduce(fn): @@ -41,6 +42,8 @@ def _initialize_torch(): sympy.ceiling: torch.ceil, sympy.floor: torch.floor, sympy.log: torch.log, + log2: torch.log2, + log10: torch.log10, sympy.exp: torch.exp, sympy.sqrt: torch.sqrt, sympy.cos: torch.cos,