Skip to content

Commit

Permalink
feat: automatically suggest related keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 5, 2024
1 parent 2f528dc commit ea7ced4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
19 changes: 16 additions & 3 deletions pysr/sr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Define the PySRRegressor scikit-learn interface."""

import copy
import difflib
import inspect
import os
import pickle as pkl
import re
Expand Down Expand Up @@ -907,9 +909,11 @@ def __init__(
FutureWarning,
)
else:
raise TypeError(
f"{k} is not a valid keyword argument for PySRRegressor."
)
suggested_keywords = self._suggest_keywords(k)
err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
if len(suggested_keywords) > 0:
err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
raise TypeError(err_msg)

@classmethod
def from_file(
Expand Down Expand Up @@ -1991,6 +1995,15 @@ def fit(

return self

def _suggest_keywords(self, k: str) -> List[str]:
valid_keywords = [
param
for param in inspect.signature(self.__init__).parameters
if param not in ["self", "kwargs"]
]
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
return suggestions

def refresh(self, checkpoint_file=None) -> None:
"""
Update self.equations_ with any new options passed.
Expand Down
20 changes: 20 additions & 0 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,26 @@ def test_bad_kwargs(self):
model.get_best()
print("Failed", opt["kwargs"])

def test_suggest_keywords(self):
model = PySRRegressor()
# Easy
self.assertEqual(model._suggest_keywords("loss_function"), ["loss_function"])

# More complex, and with error
with self.assertRaises(TypeError) as cm:
model = PySRRegressor(ncyclesperiterationn=5)

self.assertIn("ncyclesperiterationn is not a valid keyword", str(cm.exception))
self.assertIn("Did you mean", str(cm.exception))
self.assertIn("ncycles_per_iteration or", str(cm.exception))
self.assertIn("niteration", str(cm.exception))

# Farther matches (this might need to be changed)
with self.assertRaises(TypeError) as cm:
model = PySRRegressor(operators=["+", "-"])

self.assertIn("unary_operators or binary_operators", str(cm.exception))


TRUE_PREAMBLE = "\n".join(
[
Expand Down

0 comments on commit ea7ced4

Please sign in to comment.