Skip to content

Commit

Permalink
Store sr_options_ and rename state to sr_state_
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Sep 17, 2023
1 parent 135a464 commit 6c92e1c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
40 changes: 28 additions & 12 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
Path to the temporary equations directory.
equation_file_ : str
Output equation file name produced by the julia backend.
raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
sr_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
The state for the julia SymbolicRegression.jl backend post fitting.
sr_options_ : PyCall.jlwrap
The options used by `SymbolicRegression.jl`, created during
a call to `.fit`. You may use this to manually call functions
in `SymbolicRegression` which take an `::Options` argument.
equation_file_contents_ : list[pandas.DataFrame]
Contents of the equation file output by the Julia backend.
show_pickle_warnings_ : bool
Expand Down Expand Up @@ -1031,7 +1035,7 @@ def __getstate__(self):
serialization.
Thus, for `PySRRegressor` to support pickle serialization, the
`raw_julia_state_` attribute must be hidden from pickle. This will
`sr_state_` attribute must be hidden from pickle. This will
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
but does allow all other attributes of a fitted `PySRRegressor` estimator
to be serialized. Note: Jax and Torch format equations are also removed
Expand All @@ -1041,9 +1045,9 @@ def __getstate__(self):
show_pickle_warning = not (
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
)
if "raw_julia_state_" in state and show_pickle_warning:
if ("sr_state_" in state or "sr_options_" in state) and show_pickle_warning:
warnings.warn(
"raw_julia_state_ cannot be pickled and will be removed from the "
"sr_state_ and sr_options_ cannot be pickled and will be removed from the "
"serialized instance. This will prevent a `warm_start` fit of any "
"model that is deserialized via `pickle.load()`."
)
Expand All @@ -1055,7 +1059,10 @@ def __getstate__(self):
"serialized instance. When loading the model, please redefine "
f"`{state_key}` at runtime."
)
state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
state_keys_to_clear = [
"sr_state_",
"sr_options_",
] + state_keys_containing_lambdas
pickled_state = {
key: (None if key in state_keys_to_clear else value)
for key, value in state.items()
Expand Down Expand Up @@ -1105,6 +1112,14 @@ def equations(self): # pragma: no cover
)
return self.equations_

@property
def raw_julia_state_(self): # pragma: no cover
warnings.warn(
"PySRRegressor.raw_julia_state_ is now deprecated. "
"Please use PySRRegressor.sr_state_ instead.",
)
return self.sr_state_

def get_best(self, index=None):
"""
Get best equation using `model_selection`.
Expand Down Expand Up @@ -1605,7 +1620,7 @@ def _run(self, X, y, mutated_params, weights, seed):

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
options = SymbolicRegression.Options(
self.sr_options_ = SymbolicRegression.Options(
binary_operators=Main.eval(str(binary_operators).replace("'", "")),
unary_operators=Main.eval(str(unary_operators).replace("'", "")),
bin_constraints=bin_constraints,
Expand Down Expand Up @@ -1704,7 +1719,7 @@ def _run(self, X, y, mutated_params, weights, seed):

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl
self.raw_julia_state_ = SymbolicRegression.equation_search(
self.sr_state_ = SymbolicRegression.equation_search(
Main.X,
Main.y,
weights=Main.weights,
Expand All @@ -1714,10 +1729,10 @@ def _run(self, X, y, mutated_params, weights, seed):
y_variable_names=y_variable_names,
X_units=self.X_units_,
y_units=self.y_units_,
options=options,
options=self.sr_options_,
numprocs=cprocs,
parallelism=parallelism,
saved_state=self.raw_julia_state_,
saved_state=self.sr_state_,
return_state=True,
addprocs_function=cluster_manager,
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
Expand Down Expand Up @@ -1786,10 +1801,10 @@ def fit(
Fitted estimator.
"""
# Init attributes that are not specified in BaseEstimator
if self.warm_start and hasattr(self, "raw_julia_state_"):
if self.warm_start and hasattr(self, "sr_state_"):
pass
else:
if hasattr(self, "raw_julia_state_"):
if hasattr(self, "sr_state_"):
warnings.warn(
"The discovered expressions are being reset. "
"Please set `warm_start=True` if you wish to continue "
Expand All @@ -1799,7 +1814,8 @@ def fit(
self.equations_ = None
self.nout_ = 1
self.selection_mask_ = None
self.raw_julia_state_ = None
self.sr_state_ = None
self.sr_options_ = None
self.X_units_ = None
self.y_units_ = None

Expand Down
4 changes: 2 additions & 2 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_high_precision_search_custom_loss(self):
from pysr.sr import Main

# We should have that the model state is now a Float64 hof:
Main.test_state = model.raw_julia_state_
Main.test_state = model.sr_state_
self.assertTrue(Main.eval("typeof(test_state[2]).parameters[1] == Float64"))

def test_multioutput_custom_operator_quiet_custom_complexity(self):
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_empty_operators_single_input_warm_start(self):
from pysr.sr import Main

# We should have that the model state is now a Float32 hof:
Main.test_state = regressor.raw_julia_state_
Main.test_state = regressor.sr_state_
self.assertTrue(Main.eval("typeof(test_state[2]).parameters[1] == Float32"))
# This should exit almost immediately, and use the old equations
regressor.fit(X, y)
Expand Down

0 comments on commit 6c92e1c

Please sign in to comment.