From dee13565a9366548aff358e083e9d60f59674a5e Mon Sep 17 00:00:00 2001 From: Sebastian Heuchler Date: Wed, 4 Sep 2024 15:17:09 +0200 Subject: [PATCH 1/4] Translate log2 and log10 to their direct sympy equivalents instead of log(x)/log(base) --- pysr/export_sympy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index f38593413..7864d07e0 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -39,8 +39,8 @@ "erf": sympy.erf, "erfc": sympy.erfc, "log": lambda x: sympy.log(x), - "log10": lambda x: sympy.log(x, 10), - "log2": lambda x: sympy.log(x, 2), + "log10": lambda x: sympy.codegen.cfunctions.log10(x), + "log2": lambda x: sympy.codegen.cfunctions.log2(x), "log1p": lambda x: sympy.log(x + 1), "log_abs": lambda x: sympy.log(abs(x)), "log10_abs": lambda x: sympy.log(abs(x), 10), From f153fdf86d74b066217f10d8671901c14ce99cf1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 6 Dec 2024 21:59:17 +0000 Subject: [PATCH 2/4] fix: ensure we trigger codegen import for sympy --- pysr/export_sympy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index 3a6cf64f3..06996e6af 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -4,6 +4,7 @@ import sympy # type: ignore from sympy import sympify +from sympy.codegen.cfunctions import log2, log10 from .utils import ArrayLike @@ -39,8 +40,8 @@ "erf": sympy.erf, "erfc": sympy.erfc, "log": lambda x: sympy.log(x), - "log10": lambda x: sympy.codegen.cfunctions.log10(x), - "log2": lambda x: sympy.codegen.cfunctions.log2(x), + "log10": lambda x: log10(x), + "log2": lambda x: log2(x), "log1p": lambda x: sympy.log(x + 1), "log_abs": lambda x: sympy.log(abs(x)), "log10_abs": lambda x: sympy.log(abs(x), 10), From 678fce013f90804f962e17faf4b9b0e5d5f8e649 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 6 Dec 2024 22:11:11 +0000 Subject: [PATCH 3/4] feat: correct mappings for jax and torch log2, log10 --- pysr/export_jax.py | 3 +++ pysr/export_torch.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/pysr/export_jax.py b/pysr/export_jax.py index ba36bfc66..506ac7d41 100644 --- a/pysr/export_jax.py +++ b/pysr/export_jax.py @@ -1,5 +1,6 @@ import numpy as np # noqa: F401 import sympy # type: ignore +from sympy.codegen.cfunctions import log2, log10 # Special since need to reduce arguments. MUL = 0 @@ -15,6 +16,8 @@ sympy.ceiling: "jnp.ceil", sympy.floor: "jnp.floor", sympy.log: "jnp.log", + log2: "jnp.log2", + log10: "jnp.log10", sympy.exp: "jnp.exp", sympy.sqrt: "jnp.sqrt", sympy.cos: "jnp.cos", diff --git a/pysr/export_torch.py b/pysr/export_torch.py index be3d6a163..c846d80e8 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -5,6 +5,7 @@ import numpy as np # noqa: F401 import sympy # type: ignore +from sympy.codegen.cfunctions import log2, log10 def _reduce(fn): @@ -41,6 +42,8 @@ def _initialize_torch(): sympy.ceiling: torch.ceil, sympy.floor: torch.floor, sympy.log: torch.log, + log2: torch.log2, + log10: torch.log10, sympy.exp: torch.exp, sympy.sqrt: torch.sqrt, sympy.cos: torch.cos, From 450e674fd238c0e16ea4d111bef5e92424333217 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 6 Dec 2024 22:22:57 +0000 Subject: [PATCH 4/4] test: fix mypy errors --- pysr/export_jax.py | 2 +- pysr/export_sympy.py | 2 +- pysr/export_torch.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pysr/export_jax.py b/pysr/export_jax.py index 506ac7d41..4f8ead200 100644 --- a/pysr/export_jax.py +++ b/pysr/export_jax.py @@ -1,6 +1,6 @@ import numpy as np # noqa: F401 import sympy # type: ignore -from sympy.codegen.cfunctions import log2, log10 +from sympy.codegen.cfunctions import log2, log10 # type: ignore # Special since need to reduce arguments. MUL = 0 diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index 06996e6af..925c284b6 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -4,7 +4,7 @@ import sympy # type: ignore from sympy import sympify -from sympy.codegen.cfunctions import log2, log10 +from sympy.codegen.cfunctions import log2, log10 # type: ignore from .utils import ArrayLike diff --git a/pysr/export_torch.py b/pysr/export_torch.py index c846d80e8..eb3ccd8aa 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -5,7 +5,7 @@ import numpy as np # noqa: F401 import sympy # type: ignore -from sympy.codegen.cfunctions import log2, log10 +from sympy.codegen.cfunctions import log2, log10 # type: ignore def _reduce(fn):