Skip to content

Commit

Permalink
feat: get complexity_mapping working
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Nov 28, 2024
1 parent 809bb74 commit a0c2ef3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
8 changes: 6 additions & 2 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def __init__(
elementwise_loss: Optional[str] = None,
loss_function: Optional[str] = None,
complexity_of_operators: Optional[Dict[str, Union[int, float]]] = None,
complexity_of_constants: Union[int, float] = 1,
complexity_of_constants: Optional[Union[int, float]] = None,
complexity_of_variables: Optional[Union[int, float]] = None,
complexity_mapping: Optional[str] = None,
parsimony: float = 0.0032,
Expand Down Expand Up @@ -1889,6 +1889,10 @@ def _run(
)
output_list.append(jl_op)

complexity_mapping = (
jl.seval(self.complexity_mapping) if self.complexity_mapping else None
)

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
options = SymbolicRegression.Options(
Expand All @@ -1899,7 +1903,7 @@ def _run(
complexity_of_operators=complexity_of_operators,
complexity_of_constants=self.complexity_of_constants,
complexity_of_variables=complexity_of_variables,
complexity_mapping=self.complexity_mapping,
complexity_mapping=complexity_mapping,
expression_type=self.expression_spec_.julia_expression_type(),
expression_options=self.expression_spec_.julia_expression_options(),
nested_constraints=nested_constraints,
Expand Down
15 changes: 13 additions & 2 deletions pysr/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def test_jl_function_error(self):
str(cm.exception),
)

def test_template_expressions(self):
def test_template_expressions_and_custom_complexity(self):
# Create random data between -1 and 1
X = self.rstate.uniform(-1, 1, (100, 2))

Expand All @@ -535,7 +535,9 @@ def test_template_expressions(self):
binary_operators=["+", "-", "*", "/"],
unary_operators=[], # No sin operator!
maxsize=10,
early_stop_condition="stop_if(loss, complexity) = loss < 1e-10 && complexity == 3",
early_stop_condition="stop_if(loss, complexity) = loss < 1e-10 && complexity == 6",
# Custom complexity *function*:
complexity_mapping="my_complexity(ex) = sum(t -> 2, get_tree(ex))",
**self.default_test_kwargs,
)

Expand All @@ -549,6 +551,15 @@ def test_template_expressions(self):
test_mse = np.mean((y_test - y_pred) ** 2)
self.assertLess(test_mse, 1e-5)

# Check there is a row with complexity 6 and MSE < 1e-10
df = model.equations_
good_rows = df[(df.complexity == 6) & (df.loss < 1e-10)]
self.assertGreater(len(good_rows), 0)

# Check there are NO rows with lower complexity and MSE < 1e-10
simpler_good_rows = df[(df.complexity < 6) & (df.loss < 1e-10)]
self.assertEqual(len(simpler_good_rows), 0)

# Make sure that a nice error is raised if we try to get the sympy expression:
# f"`expression_spec={self.expression_spec_}` does not support sympy export."
with self.assertRaises(ValueError) as cm:
Expand Down

0 comments on commit a0c2ef3

Please sign in to comment.