diff --git a/pysr/julia_helpers.py b/pysr/julia_helpers.py index fbbc3f72a..18d4a6cf3 100644 --- a/pysr/julia_helpers.py +++ b/pysr/julia_helpers.py @@ -41,8 +41,8 @@ def jl_array(x, dtype=None): return jl_convert(jl.Array[dtype], x) -def jl_is_function(f): - return jl.seval("op -> op isa Function")(f) +def jl_is_function(f) -> bool: + return cast(bool, jl.seval("op -> op isa Function")(f)) def jl_serialize(obj: Any) -> NDArray[np.uint8]: diff --git a/pysr/sr.py b/pysr/sr.py index bba2021dd..8de536e5f 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -13,7 +13,7 @@ from io import StringIO from multiprocessing import cpu_count from pathlib import Path -from typing import Callable, Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast import numpy as np import pandas as pd @@ -44,6 +44,7 @@ _load_cluster_manager, jl_array, jl_deserialize, + jl_is_function, jl_serialize, ) from .julia_import import SymbolicRegression, jl @@ -1695,11 +1696,25 @@ def _run( optimize=self.weight_optimize, ) + jl_binary_operators: list[Any] = [] + jl_unary_operators: list[Any] = [] + for input_list, output_list, name in [ + (binary_operators, jl_binary_operators, "binary"), + (unary_operators, jl_unary_operators, "unary"), + ]: + for op in input_list: + jl_op = jl.seval(op) + if not jl_is_function(jl_op): + raise ValueError( + f"When building `{name}_operators`, `'{op}'` did not return a Julia function" + ) + output_list.append(jl_op) + # Call to Julia backend. # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl options = SymbolicRegression.Options( - binary_operators=jl.seval(str(binary_operators).replace("'", "")), - unary_operators=jl.seval(str(unary_operators).replace("'", "")), + binary_operators=jl_array(jl_binary_operators, dtype=jl.Function), + unary_operators=jl_array(jl_unary_operators, dtype=jl.Function), bin_constraints=jl_array(bin_constraints), una_constraints=jl_array(una_constraints), complexity_of_operators=complexity_of_operators, diff --git a/pysr/test/test.py b/pysr/test/test.py index 404d5a139..1872da9dc 100644 --- a/pysr/test/test.py +++ b/pysr/test/test.py @@ -431,6 +431,16 @@ def test_load_model_simple(self): ) np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X)) + def test_jl_function_error(self): + # TODO: Move this to better class + with self.assertRaises(ValueError) as cm: + PySRRegressor(unary_operators=["1"]).fit([[1]], [1]) + + self.assertIn( + "When building `unary_operators`, `'1'` did not return a Julia function", + str(cm.exception), + ) + def manually_create_model(equations, feature_names=None): if feature_names is None: