Skip to content

Commit

Permalink
Save options to PySRRegressor
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Feb 10, 2024
1 parent e957e34 commit 70b842a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
14 changes: 12 additions & 2 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Functions for initializing the Julia environment and installing deps."""
import warnings

import numpy as np
from juliacall import convert as jl_convert # type: ignore

from .julia_import import jl

jl.seval("using Serialization: Serialization")
jl.seval("using PythonCall: PythonCall")

Serialization = jl.Serialization
PythonCall = jl.PythonCall


def install(*args, **kwargs):
del args, kwargs
Expand Down Expand Up @@ -35,10 +39,16 @@ def jl_array(x):
return jl_convert(jl.Array, x)


def jl_deserialize_s(s):
def jl_serialize(obj):
    """Serialize a Julia object and return the resulting bytes as a numpy array.

    The object is written with Julia's `Serialization.serialize` into an
    in-memory `IOBuffer`; the buffer's contents are then taken out and
    wrapped with `np.array`.
    """
    stream = jl.IOBuffer()
    Serialization.serialize(stream, obj)
    # `take_b` is how the file spells Julia's `take!` through the `jl` bridge.
    raw_bytes = jl.take_b(stream)
    return np.array(raw_bytes)


def jl_deserialize(s):
    """Rebuild a Julia object from a serialized byte stream.

    Parameters
    ----------
    s : array-like or None
        Bytes to deserialize; they are converted with `jl_array` and written
        into a Julia `IOBuffer`. `None` is passed straight through.

    Returns
    -------
    The object returned by Julia's `Serialization.deserialize`, or `None`
    when `s` is `None`.
    """
    if s is None:
        return s
    buf = jl.IOBuffer()
    jl.write(buf, jl_array(s))
    # Rewind to the start so deserialization reads from the first byte written.
    jl.seekstart(buf)
    # Single return through the module-level `Serialization` alias; the
    # duplicate `jl.Serialization.deserialize` return was dead code.
    return Serialization.deserialize(buf)
40 changes: 25 additions & 15 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
from .export_torch import sympy2torch
from .feature_selection import run_feature_selection
from .julia_helpers import (
PythonCall,
_escape_filename,
_load_cluster_manager,
jl_array,
jl_deserialize_s,
jl_deserialize,
jl_serialize,
)
from .julia_import import SymbolicRegression, jl
from .utils import (
Expand Down Expand Up @@ -602,11 +604,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
Path to the temporary equations directory.
equation_file_ : str
Output equation file name produced by the julia backend.
raw_julia_state_stream_ : ndarray
julia_state_stream_ : ndarray
The serialized state for the julia SymbolicRegression.jl backend (after fitting),
stored as an array of uint8, produced by Julia's Serialization.serialize function.
julia_state_ : ndarray
julia_state_
The deserialized state.
julia_options_stream_ : ndarray
The serialized julia options, stored as an array of uint8,
produced by Julia's Serialization.serialize function.
julia_options_
The deserialized julia options.
equation_file_contents_ : list[pandas.DataFrame]
Contents of the equation file output by the Julia backend.
show_pickle_warnings_ : bool
Expand Down Expand Up @@ -1053,7 +1059,7 @@ def __getstate__(self):
serialization.
Thus, for `PySRRegressor` to support pickle serialization, the
`raw_julia_state_stream_` attribute must be hidden from pickle. This will
`julia_state_stream_` attribute must be hidden from pickle. This will
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
but does allow all other attributes of a fitted `PySRRegressor` estimator
to be serialized. Note: Jax and Torch format equations are also removed
Expand Down Expand Up @@ -1121,15 +1127,19 @@ def equations(self): # pragma: no cover
)
return self.equations_

@property
def julia_options_(self):
    """Return the Julia options deserialized from `julia_options_stream_`.

    The stream is handed to `jl_deserialize` as-is, so whatever that helper
    returns for the stored value is returned here.
    """
    options_stream = self.julia_options_stream_
    return jl_deserialize(options_stream)

@property
def julia_state_(self):
    """Return the Julia search state deserialized from `julia_state_stream_`.

    The stream is handed to `jl_deserialize` as-is, so whatever that helper
    returns for the stored value is returned here.
    """
    # Read the renamed `julia_state_stream_` attribute via `jl_deserialize`;
    # the old `jl_deserialize_s(self.raw_julia_state_stream_)` return was stale
    # and shadowed this one.
    return jl_deserialize(self.julia_state_stream_)

@property
def raw_julia_state_(self):
warnings.warn(
"PySRRegressor.raw_julia_state_ is now deprecated. "
"Please use PySRRegressor.julia_state_ instead, or raw_julia_state_stream_ "
"Please use PySRRegressor.julia_state_ instead, or julia_state_stream_ "
"for the raw stream of bytes.",
FutureWarning,
)
Expand Down Expand Up @@ -1675,6 +1685,8 @@ def _run(self, X, y, mutated_params, weights, seed):
define_helper_functions=False,
)

self.julia_options_stream_ = jl_serialize(options)

# Convert data to desired precision
test_X = np.array(X)
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
Expand Down Expand Up @@ -1718,7 +1730,7 @@ def _run(self, X, y, mutated_params, weights, seed):
else:
jl_y_variable_names = None

jl.PythonCall.GC.disable()
PythonCall.GC.disable()
out = SymbolicRegression.equation_search(
jl_X,
jl_y,
Expand All @@ -1741,12 +1753,9 @@ def _run(self, X, y, mutated_params, weights, seed):
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
verbosity=int(self.verbosity),
)
jl.PythonCall.GC.enable()
PythonCall.GC.enable()

# Serialize output (for pickling)
buf = jl.IOBuffer()
jl.Serialization.serialize(buf, out)
self.raw_julia_state_stream_ = np.array(jl.take_b(buf))
self.julia_state_stream_ = jl_serialize(out)

# Set attributes
self.equations_ = self.get_hof()
Expand Down Expand Up @@ -1810,10 +1819,10 @@ def fit(
Fitted estimator.
"""
# Init attributes that are not specified in BaseEstimator
if self.warm_start and hasattr(self, "raw_julia_state_stream_"):
if self.warm_start and hasattr(self, "julia_state_stream_"):
pass
else:
if hasattr(self, "raw_julia_state_stream_"):
if hasattr(self, "julia_state_stream_"):
warnings.warn(
"The discovered expressions are being reset. "
"Please set `warm_start=True` if you wish to continue "
Expand All @@ -1823,7 +1832,8 @@ def fit(
self.equations_ = None
self.nout_ = 1
self.selection_mask_ = None
self.raw_julia_state_stream_ = None
self.julia_state_stream_ = None
self.julia_options_stream_ = None
self.X_units_ = None
self.y_units_ = None

Expand Down

0 comments on commit 70b842a

Please sign in to comment.